diff --git a/.gitignore b/.gitignore index 539f959076..f883af441e 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,7 @@ config.mk bli_config.h +bli_addon.h # -- monolithic headers -- diff --git a/CMakeLists.txt b/CMakeLists.txt index 0483435679..2752df7e68 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,7 +17,7 @@ set(AOCL_BLIS_ZEN TRUE) set (PYTHON_EXE "python") if ("${AOCL_BLIS_FAMILY}" STREQUAL "") - message(FATAL_ERROR "Machine configuration missing! Select one of zen, zen2, zen3 or amdzen") + message(FATAL_ERROR "Machine configuration missing! Select one of zen, zen2, zen3, zen4 or amdzen") endif () if (${AOCL_BLIS_FAMILY} STREQUAL "auto") @@ -50,20 +50,32 @@ elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen3") add_definitions(-DBLIS_KERNELS_ZEN2) add_definitions(-DBLIS_KERNELS_ZEN) add_definitions(-DBLIS_KERNELS_HASWELL) +elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen4") + add_definitions(-DBLIS_FAMILY_ZEN4) + add_definitions(-DBLIS_CONFIG_ZEN4) + add_definitions(-DBLIS_KERNELS_SKX) + add_definitions(-DBLIS_KERNELS_ZEN4) + add_definitions(-DBLIS_KERNELS_ZEN3) + add_definitions(-DBLIS_KERNELS_ZEN2) + add_definitions(-DBLIS_KERNELS_ZEN) + add_definitions(-DBLIS_KERNELS_HASWELL) elseif (${AOCL_BLIS_FAMILY} STREQUAL "amdzen") set(AOCL_BLIS_ZEN FALSE) add_definitions(-DBLIS_FAMILY_AMDZEN) + add_definitions(-DBLIS_CONFIG_ZEN4) add_definitions(-DBLIS_CONFIG_ZEN3) add_definitions(-DBLIS_CONFIG_ZEN2) add_definitions(-DBLIS_CONFIG_ZEN) add_definitions(-DBLIS_CONFIG_GENERIC) + add_definitions(-DBLIS_KERNELS_SKX) + add_definitions(-DBLIS_KERNELS_ZEN4) add_definitions(-DBLIS_KERNELS_ZEN3) add_definitions(-DBLIS_KERNELS_ZEN2) add_definitions(-DBLIS_KERNELS_HASWELL) add_definitions(-DBLIS_KERNELS_ZEN) add_definitions(-DBLIS_KERNELS_GENERIC) else () - message(FATAL_ERROR "Wrong machine configuration. Select one of zen, zen2, zen3 or amdzen") + message(FATAL_ERROR "Wrong machine configuration. 
Select one of zen, zen2, zen3, zen4 or amdzen") endif () set(TARGET_ARCH ${AOCL_BLIS_FAMILY}) @@ -95,6 +107,8 @@ option (ENABLE_UPPERCASE_API "export APIs with uppercase" OFF) option (ENABLE_COMPLEX_RETURN_INTEL "Enable complex_return_intel" OFF) option (ENABLE_TRSM_PREINVERSION "Enable TRSM preinversion" ON) option (ENABLE_AOCL_DYNAMIC "Enable Dynamic Multi-threading" OFF) +option(DISABLE_BLIS_ARCH_TYPE "Disable BLIS_ARCH_TYPE functionality" OFF) +option(RENAME_BLIS_ARCH_TYPE "Rename BLIS_ARCH_TYPE env var renamed to supplied value" BLIS_ARCH_TYPE) if (${AOCL_BLIS_FAMILY} STREQUAL "amdzen") set(REF_KERNEL_MIRRORING_PY "${CMAKE_SOURCE_DIR}/build/blis_ref_kernel_mirror.py") @@ -270,6 +284,21 @@ else() endif() endif() +if(DISABLE_BLIS_ARCH_TYPE) + set(BLIS_DISABLE_BLIS_ARCH_TYPE TRUE) +else() + set(BLIS_DISABLE_BLIS_ARCH_TYPE FALSE) +endif() + +if(RENAME_BLIS_ARCH_TYPE) + set(__blis_arch_type_name TRUE) + set(rename_blis_arch_type "${RENAME_BLIS_ARCH_TYPE}") +else() + set(__blis_arch_type_name TRUE) + set(rename_blis_arch_type "BLIS_ARCH_TYPE") +endif() + + #print configurations message("---cmake configurations---") message(CMAKE_C_COMPILER_ID : ${CMAKE_C_COMPILER_ID}) @@ -291,7 +320,8 @@ message(BLIS_ENABLE_MEMKIND : ${BLIS_ENABLE_MEMKIND}) message(BLIS_ENABLE_PRAGMA_OMP_SIMD : ${BLIS_ENABLE_PRAGMA_OMP_SIMD}) message(BLIS_ENABLE_SANDBOX : ${BLIS_ENABLE_SANDBOX}) message(BLIS_ENABLE_SHARED : ${BLIS_ENABLE_SHARED}) - +message(DISABLE_BLIS_ARCH_TYPE : ${DISABLE_BLIS_ARCH_TYPE}) +message(RENAME_BLIS_ARCH_TYPE : ${RENAME_BLIS_ARCH_TYPE}) SET(ENABLE_SIMD_FLAGS "AVX2" CACHE STRING "Set compiler SIMD flags") SET_PROPERTY(CACHE ENABLE_SIMD_FLAGS PROPERTY STRINGS none SSE2 AVX AVX2) @@ -304,6 +334,15 @@ elseif(${ENABLE_SIMD_FLAGS} MATCHES "SSE2") add_definitions(/arch:SSE2) endif() +if(${TARGET_ARCH} STREQUAL zen4 OR + ${TARGET_ARCH} STREQUAL amdzen) + set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen4/1/bli_amaxv_zen_int_avx512.c PROPERTIES COMPILE_FLAGS /arch:AVX512) + set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen4/3/bli_gemmtrsm_l_zen_16x14.c PROPERTIES COMPILE_FLAGS /arch:AVX512) + set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen4/3/bli_gemmtrsm_u_zen_16x14.c PROPERTIES COMPILE_FLAGS /arch:AVX512) + set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/skx/3/bli_dgemm_skx_asm_16x14.c PROPERTIES COMPILE_FLAGS /arch:AVX512) + set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c PROPERTIES COMPILE_FLAGS /arch:AVX512) +endif() + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /W0 ") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Oi") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /MP") @@ -390,11 +429,13 @@ include_directories(${CMAKE_SOURCE_DIR}/config/generic) include_directories(${CMAKE_SOURCE_DIR}/config/zen) include_directories(${CMAKE_SOURCE_DIR}/config/zen2) include_directories(${CMAKE_SOURCE_DIR}/config/zen3) +include_directories(${CMAKE_SOURCE_DIR}/config/zen4) if(${AOCL_BLIS_FAMILY} STREQUAL "amdzen") include_directories(${CMAKE_BINARY_DIR}/ref_kernels/generic) include_directories(${CMAKE_BINARY_DIR}/ref_kernels/zen) include_directories(${CMAKE_BINARY_DIR}/ref_kernels/zen2) include_directories(${CMAKE_BINARY_DIR}/ref_kernels/zen3) + include_directories(${CMAKE_BINARY_DIR}/ref_kernels/zen4) endif() include_directories(${CMAKE_SOURCE_DIR}/ref_kernels) include_directories(${CMAKE_SOURCE_DIR}/kernels) @@ -410,7 +451,9 @@ include_directories(${CMAKE_SOURCE_DIR}/kernels/zen/2) 
 include_directories(${CMAKE_SOURCE_DIR}/kernels/zen/3)
 include_directories(${CMAKE_SOURCE_DIR}/kernels/zen/3/sup)
 include_directories(${CMAKE_SOURCE_DIR}/kernels/zen2)
-
+include_directories(${CMAKE_SOURCE_DIR}/kernels/zen4)
+include_directories(${CMAKE_SOURCE_DIR}/kernels/skx)
+include_directories(${CMAKE_SOURCE_DIR}/kernels/skx/3)
 file(GLOB headers ${CMAKE_SOURCE_DIR}/*.h)
 
 # Monolithic Header generation
@@ -429,6 +472,7 @@ elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen2")
 " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen/"
 " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen2/"
 " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen3/"
+" ${CMAKE_CURRENT_SOURCE_DIR}/config/zen4/"
 " ${CMAKE_CURRENT_SOURCE_DIR}/config/generic/"
 " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen/"
 " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/haswell/"
@@ -541,7 +585,8 @@ message( STATUS "Generating monolithic cblas header file :" ${CMD_OUTPUT})
 # setting the blis version string
 file (STRINGS "version" BLIS_VERSION)
 set(BLIS_VERSION_STRING ${BLIS_VERSION})
-add_definitions(-DBLIS_VERSION_STRING="AOCL BLIS ${BLIS_VERSION_STRING}")
+string(TIMESTAMP BUILD_DATE "%Y%m%d")
+add_definitions(-DBLIS_VERSION_STRING="AOCL-BLIS ${BLIS_VERSION_STRING} Build ${BUILD_DATE}")
 
 if(BUILD_SHARED_LIBS)
 add_library("${PROJECT_NAME}" SHARED ${CMAKE_SOURCE_DIR}/bli_config.h
@@ -574,6 +619,7 @@ add_subdirectory(frame)
 add_subdirectory(aocl_dtl)
 add_subdirectory(test)
 add_subdirectory(testsuite)
+add_subdirectory(bench)
 if(ENABLE_TESTCPP_TESTING)
 add_subdirectory(vendor/testcpp)
 endif()
diff --git a/CREDITS b/CREDITS
index c6d5d7151a..fd0bcb5b32 100644
--- a/CREDITS
+++ b/CREDITS
@@ -23,6 +23,7 @@ but many others have contributed code and feedback, including
 Dilyn Corner @dilyn-corner
 Mat Cross @matcross (NAG)
 @decandia50
+ Daniël de Kok @danieldk (Explosion)
 Kay Dewhurst @jkd2016 (Max Planck Institute, Halle, Germany)
 Jeff Diamond (Oracle)
 Johannes Dieterich @iotamudelta
@@ -45,6 +46,7 @@ but many others have contributed code and feedback, including
 Matthew Honnibal @honnibal
 Stefan Husmann @stefanhusmann
 Francisco Igual @figual (Universidad Complutense de Madrid)
+ Madeesh Kannan @shadeMe
 Tony Kelman @tkelman
 Lee Killough @leekillough (Cray)
 Mike Kistler @mkistler (IBM, Austin Research Laboratory)
diff --git a/Makefile b/Makefile
index 1658e16de2..1f86acc7e5 100644
--- a/Makefile
+++ b/Makefile
@@ -116,6 +116,7 @@ BASE_OBJ_FRAME_PATH := $(BASE_OBJ_PATH)/$(FRAME_DIR)
 BASE_OBJ_AOCLDTL_PATH := $(BASE_OBJ_PATH)/$(AOCLDTL_DIR)
 BASE_OBJ_REFKERN_PATH := $(BASE_OBJ_PATH)/$(REFKERN_DIR)
 BASE_OBJ_KERNELS_PATH := $(BASE_OBJ_PATH)/$(KERNELS_DIR)
+BASE_OBJ_ADDON_PATH := $(BASE_OBJ_PATH)/$(ADDON_DIR)
 BASE_OBJ_SANDBOX_PATH := $(BASE_OBJ_PATH)/$(SANDBOX_DIR)
 
 # --- Define install target names for static libraries ---
@@ -212,6 +213,20 @@ MK_REFKERN_OBJS := $(foreach arch, $(CONFIG_LIST), \
 
 # Generate object file paths for all of the portable framework source code.
 MK_FRAME_OBJS := $(call gen-obj-paths-from-src,$(FRAME_SRC_SUFS),$(MK_FRAME_SRC),$(FRAME_PATH),$(BASE_OBJ_FRAME_PATH))
 
+# Generate object file paths for the addon source code. If one or more addons
+# were not enabled at configure-time, these variables will be empty.
+# NOTE: We separate the source and objects into kernel and non-kernel lists.
+MK_ADDON_KERS_SRC := $(foreach addon, $(ADDON_LIST), \ + $(filter $(ADDON_PATH)/$(addon)/$(KERNELS_DIR)/%, \ + $(MK_ADDON_SRC)) \ + ) +MK_ADDON_OTHER_SRC := $(foreach addon, $(ADDON_LIST), \ + $(filter-out $(ADDON_PATH)/$(addon)/$(KERNELS_DIR)/%, \ + $(MK_ADDON_SRC)) \ + ) +MK_ADDON_KERS_OBJS := $(call gen-obj-paths-from-src,$(ADDON_SRC_SUFS),$(MK_ADDON_KERS_SRC),$(ADDON_PATH),$(BASE_OBJ_ADDON_PATH)) +MK_ADDON_OTHER_OBJS := $(call gen-obj-paths-from-src,$(ADDON_SRC_SUFS),$(MK_ADDON_OTHER_SRC),$(ADDON_PATH),$(BASE_OBJ_ADDON_PATH)) +MK_ADDON_OBJS := $(MK_ADDON_KERS_OBJS) $(MK_ADDON_OTHER_OBJS) # AMD has optimized some of the framework files, these optimizations # may not be compatible with other platforms. # @@ -236,8 +251,6 @@ endif # Generate object file paths for all of the debgu and trace logger. MK_AOCLDTL_OBJS := $(call gen-obj-paths-from-src,$(AOCLDTL_SRC_SUFS),$(MK_AOCLDTL_SRC),$(AOCLDTL_PATH),$(BASE_OBJ_AOCLDTL_PATH)) - - # Generate object file paths for the sandbox source code. If a sandbox was not # enabled a configure-time, this variable will we empty. MK_SANDBOX_OBJS := $(call gen-obj-paths-from-src,$(SANDBOX_SRC_SUFS),$(MK_SANDBOX_SRC),$(SANDBOX_PATH),$(BASE_OBJ_SANDBOX_PATH)) @@ -248,6 +261,7 @@ MK_BLIS_OBJS := $(MK_CONFIG_OBJS) \ $(MK_REFKERN_OBJS) \ $(MK_FRAME_OBJS) \ $(MK_AOCLDTL_OBJS) \ + $(MK_ADDON_OBJS) \ $(MK_SANDBOX_OBJS) # Optionally filter out the BLAS and CBLAS compatibility layer object files. @@ -590,6 +604,47 @@ endef # first argument: a configuration name from the union of config_list and # config_name, used to look up the CFLAGS to use during compilation. +# second argument: the C99 addon file suffix being considered. +define make-c99-addon-rule +$(BASE_OBJ_ADDON_PATH)/%.o: $(ADDON_PATH)/%.$(2) $(BLIS_H_FLAT) $(ADDON_H99_FILES) $(MAKE_DEFS_MK_PATHS) +ifeq ($(ENABLE_VERBOSE),yes) + $(CC) $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ +else + @echo "Compiling $$@" $(call get-addon-c99text-for,$(1)) + @$(CC) $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ +endif +endef + +# first argument: a configuration name from the union of config_list and +# config_name, used to look up the CFLAGS to use during compilation. +# second argument: the C99 addon file suffix being considered. +# third argument: the name of the addon being considered. +define make-c99-addon-kers-rule +$(BASE_OBJ_ADDON_PATH)/$(3)/$(KERNELS_DIR)/%.o: $(ADDON_PATH)/$(3)/$(KERNELS_DIR)/%.$(2) $(BLIS_H_FLAT) $(ADDON_H99_FILES) $(MAKE_DEFS_MK_PATHS) +ifeq ($(ENABLE_VERBOSE),yes) + $(CC) $(call get-addon-kernel-c99flags-for,$(1)) -c $$< -o $$@ +else + @echo "Compiling $$@" $(call get-addon-kernel-text-for,$(1)) + @$(CC) $(call get-addon-kernel-c99flags-for,$(1)) -c $$< -o $$@ +endif +endef + +# first argument: a configuration name from the union of config_list and +# config_name, used to look up the CFLAGS to use during compilation. +# second argument: the C++ addon file suffix being considered. +define make-cxx-addon-rule +$(BASE_OBJ_ADDON_PATH)/%.o: $(ADDON_PATH)/%.$(2) $(BLIS_H_FLAT) $(ADDON_HXX_FILES) $(MAKE_DEFS_MK_PATHS) +ifeq ($(ENABLE_VERBOSE),yes) + $(CXX) $(call get-addon-cxxflags-for,$(1)) -c $$< -o $$@ +else + @echo "Compiling $$@" $(call get-addon-cxxtext-for,$(1)) + @$(CXX) $(call get-addon-cxxflags-for,$(1)) -c $$< -o $$@ +endif +endef + +# first argument: a configuration name from the union of config_list and +# config_name, used to look up the CFLAGS to use during compilation. +# second argument: the C99 sandbox file suffix being considered. 
define make-c99-sandbox-rule $(BASE_OBJ_SANDBOX_PATH)/%.o: $(SANDBOX_PATH)/%.$(2) $(BLIS_H_FLAT) $(SANDBOX_H99_FILES) $(MAKE_DEFS_MK_PATHS) ifeq ($(ENABLE_VERBOSE),yes) @@ -600,6 +655,9 @@ else endif endef +# first argument: a configuration name from the union of config_list and +# config_name, used to look up the CFLAGS to use during compilation. +# second argument: the C++ sandbox file suffix being considered. define make-cxx-sandbox-rule $(BASE_OBJ_SANDBOX_PATH)/%.o: $(SANDBOX_PATH)/%.$(2) $(BLIS_H_FLAT) $(SANDBOX_HXX_FILES) $(MAKE_DEFS_MK_PATHS) ifeq ($(ENABLE_VERBOSE),yes) @@ -648,6 +706,22 @@ $(foreach conf, $(CONFIG_LIST), $(eval $(call make-refkern-rule,$(conf)))) $(foreach suf, $(KERNELS_SRC_SUFS), \ $(foreach kset, $(KERNEL_LIST), $(eval $(call make-kernels-rule,$(kset),$(call get-config-for-kset,$(kset)),$(suf))))) +# Instantiate the build rule for C addon files. Use the CFLAGS for the +# configuration family. +$(foreach suf, $(ADDON_C99_SUFS), \ +$(foreach conf, $(CONFIG_NAME), $(eval $(call make-c99-addon-rule,$(conf),$(suf))))) + +# Instantiate the build rule for C addon/kernels files. Use the CFLAGS for the +# configuration family. +$(foreach addon, $(ADDON_LIST), \ +$(foreach suf, $(ADDON_C99_SUFS), \ +$(foreach conf, $(CONFIG_NAME), $(eval $(call make-c99-addon-kers-rule,$(conf),$(suf),$(addon)))))) + +# Instantiate the build rule for C++ addon files. Use the CFLAGS for the +# configuration family. +$(foreach suf, $(ADDON_CXX_SUFS), \ +$(foreach conf, $(CONFIG_NAME), $(eval $(call make-cxx-addon-rule,$(conf),$(suf))))) + # Instantiate the build rule for C sandbox files. Use the CFLAGS for the # configuration family. $(foreach suf, $(SANDBOX_C99_SUFS), \ @@ -1141,6 +1215,9 @@ ifeq ($(ENABLE_VERBOSE),yes) - $(FIND) $(AOCLDTL_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) - $(FIND) $(REFKERN_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) - $(FIND) $(KERNELS_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) +ifneq ($(ADDON_LIST),) + - $(FIND) $(ADDON_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) +endif ifneq ($(SANDBOX),) - $(FIND) $(SANDBOX_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) endif @@ -1155,6 +1232,10 @@ else @- $(FIND) $(REFKERN_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) @echo "Removing makefile fragments from $(KERNELS_FRAG_PATH)" @- $(FIND) $(KERNELS_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) +ifneq ($(ADDON_LIST),) + @echo "Removing makefile fragments from $(ADDON_FRAG_PATH)" + @- $(FIND) $(ADDON_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) +endif ifneq ($(SANDBOX),) @echo "Removing makefile fragments from $(SANDBOX_FRAG_PATH)" @- $(FIND) $(SANDBOX_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) @@ -1275,6 +1356,7 @@ endif # IS_CONFIGURED distclean: cleanmk cleanh cleanlib cleantest ifeq ($(IS_CONFIGURED),yes) ifeq ($(ENABLE_VERBOSE),yes) + - $(RM_F) $(BLIS_ADDON_H) - $(RM_F) $(BLIS_CONFIG_H) - $(RM_F) $(CONFIG_MK_FILE) - $(RM_F) $(PC_OUT_FILE) @@ -1282,6 +1364,8 @@ ifeq ($(ENABLE_VERBOSE),yes) - $(RM_RF) $(LIB_DIR) - $(RM_RF) $(INCLUDE_DIR) else + @echo "Removing $(BLIS_ADDON_H)" + @$(RM_F) $(BLIS_ADDON_H) @echo "Removing $(BLIS_CONFIG_H)" @$(RM_F) $(BLIS_CONFIG_H) @echo "Removing $(CONFIG_MK_FILE)" diff --git a/addon/aocl_gemm/aocl_bf16_type.h b/addon/aocl_gemm/aocl_bf16_type.h new file mode 100644 index 0000000000..f8b2fd431a --- /dev/null +++ b/addon/aocl_gemm/aocl_bf16_type.h @@ -0,0 +1,36 @@ + +/* + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
+ Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +#ifndef AOCL_GEMM_HALF_PRECISION_TYPE_H +#define AOCL_GEMM_HALF_PRECISION_TYPE_H + +typedef int16_t bfloat16; + +#endif // AOCL_GEMM_HALF_PRECISION_TYPE_H + diff --git a/addon/aocl_gemm/aocl_gemm.h b/addon/aocl_gemm/aocl_gemm.h new file mode 100644 index 0000000000..4e971d932a --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm.h @@ -0,0 +1,41 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#ifndef BLIS_ADDON_LPGEMM +#define BLIS_ADDON_LPGEMM + +#include "aocl_gemm_post_ops.h" +#include "aocl_gemm_interface_apis.h" + +#endif // BLIS_ADDON_LPGEMM diff --git a/addon/aocl_gemm/aocl_gemm_bf16_utils.c b/addon/aocl_gemm/aocl_gemm_bf16_utils.c new file mode 100644 index 0000000000..7af08b751b --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_bf16_utils.c @@ -0,0 +1,126 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" +#include "lpgemm_reorder_bf16.h" + +AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16bf16f32of32) +{ + if ( ( k <= 0 ) || ( n <= 0 ) ) + { + return 0; // Error. + } + + // Check if avx512_bf16 ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512_bf16_supported() == FALSE ) + { + printf(" AVX512_BF16 ISA not supported by processor, cannot perform lpgemm.\n"); + return 0; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type ); + + if ( input_mat_type == A_MATRIX ) + { + return 0; // A reorder not supported. + } + + // Extra space since packing does width in multiples of 16. The bf16 + // instruction can be used as long as atleast one zmm register can be fully + // loaded; and since k_dim needs to be atleast 2, having n_dim atleast 16 + // should give 2x16=32 elements, enough for 1 zmm register.The padding is + // not rounded to NR (=64), since that would result in memory wastage. + dim_t n_reorder = make_multiple_of_n( n, 16 ); + + // Extra space since packing does length in multiples of 2. 
+ dim_t k_reorder = make_multiple_of_n( k, 2 ); + + siz_t size_req = sizeof( int16_t ) * k_reorder * n_reorder; + + return size_req; +} + +AOCL_GEMM_REORDER(bfloat16, bf16bf16f32of32) +{ + if ( ( input_buf_addr == NULL ) || ( reorder_buf_addr == NULL ) || + ( k <= 0 ) || ( n <= 0 ) || ( ldb < n ) ) + { + return; // Error. + } + + // Check if avx512_bf16 ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512_bf16_supported() == FALSE ) + { + printf(" AVX512_BF16 ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type ); + + if ( input_mat_type == A_MATRIX ) + { + return; // A reorder not supported. + } + + // Create dummy b_reorder obj. + lpgemm_obj_t b_reorder; + b_reorder.storage.aligned_buffer = reorder_buf_addr; + + // Create dummy original b obj; + lpgemm_obj_t b; + b.storage.aligned_buffer = ( void* )input_buf_addr; + b.rs = ldb; + b.width = n; + b.length = k; + + reorderb_nr64_bf16bf16f32of32( &b, &b_reorder ); +} diff --git a/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c b/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c new file mode 100644 index 0000000000..fedf3a43c5 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c @@ -0,0 +1,218 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_post_ops.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" + +AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. 
+ if ( bli_cpuid_is_avx512_bf16_supported() == FALSE ) + { + printf(" AVX512_BF16 ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + // Null check for pointers. + if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) + { + return; // Error. + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); + + /* Perform BLAS parameter checking. */ + // Transpose not supported. + if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || + ( blis_transb != BLIS_NO_TRANSPOSE ) ) + { + return; // Error. + } + + // Sanitize order input. + char order_use = + ( ( order == 'r' ) || ( order == 'R' ) || + ( order == 'c' ) || ( order == 'C' ) ) ? + order : 'r'; + + bool is_row_major = ( ( order_use == 'r' ) || ( order_use == 'R' ) ); + bool is_column_major = ( ( order_use == 'c' ) || ( order_use == 'C' ) ); + + // Row major input expected with leading dimensions >= row stride. + if ( ( is_row_major == TRUE ) && + ( ( lda < k ) || ( ldb < n ) || ( ldc < n ) ) ) + { + return; // Error. + } + // Column major input expected with leading dimensions >= column stride. + else if ( ( is_column_major == TRUE ) && + ( ( lda < m ) || ( ldb < k ) || ( ldc < m ) ) ) + { + return; // Error. + } + + // Check if dimensions are valid. + if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || + ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) + { + return; // Error. + } + + const inc_t rs_a = lda; + const inc_t cs_a = 1; + const inc_t rs_b = ldb; + const inc_t cs_b = 1; + const inc_t rs_c = ldc; + const inc_t cs_c = 1; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); + bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); + + // B matrix needs to be packed in a certain format in order to be loaded + // and used in bf16 instrution. As such the mtag_b always needs to be either + // packed or reordered. B matrix as it is (unpacked) cannot be used, and + // the mtag_b is set to packed to enable runtime packing. + if ( ( is_row_major == TRUE ) && ( mtag_b == UNPACKED ) ) + { + mtag_b = PACK; + } + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ( ( is_column_major == TRUE ) && ( mtag_a == UNPACKED ) ) + { + mtag_a = PACK; + } + + // Only unpacked A supported now. + if ( ( is_row_major == TRUE ) && ( mtag_a != UNPACKED ) ) + { + return; // Error. + } + // Inputs swapped in column major, B becomes A from kernel point of view. + else if ( ( is_column_major == TRUE ) && ( mtag_b != UNPACKED ) ) + { + return; // Error. + } + + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + ( void* )c, ( void* )( &order_use ) + ); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + bli_membrk_rntm_set_membrk( &rntm_g ); + +#ifdef BLIS_ENABLE_OPENMP + // Swapping inputs to induce row major computation for column major inputs. 
+ if ( is_column_major == TRUE ) + { + lpgemm_bf16bf16f32of32_openmp_thread_decorator + ( + n, m, k, + b, rs_b, cs_b, mtag_b, + a, rs_a, cs_a, mtag_a, + ( float* )c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, TRUE + ); + } + else + { + lpgemm_bf16bf16f32of32_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + ( float* )c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, TRUE + ); + } +#else + // Swapping inputs to induce row major computation for column major inputs. + if ( is_column_major == TRUE ) + { + lpgemm_bf16bf16f32of32_thread_decorator + ( + n, m, k, + b, rs_b, cs_b, mtag_b, + a, rs_a, cs_a, mtag_a, + ( float* )c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, TRUE + ); + } + else + { + lpgemm_bf16bf16f32of32_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + ( float* )c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, TRUE + ); + } +#endif +} diff --git a/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c b/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c new file mode 100644 index 0000000000..8f87f4dff3 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c @@ -0,0 +1,218 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_post_ops.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" + +AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512_bf16_supported() == FALSE ) + { + printf(" AVX512_BF16 ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. 
+ aocl_lpgemm_init_global_cntx(); + + // Null check for pointers. + if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) + { + return; // Error. + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); + + /* Perform BLAS parameter checking. */ + // Transpose not supported. + if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || + ( blis_transb != BLIS_NO_TRANSPOSE ) ) + { + return; // Error. + } + + // Sanitize order input. + char order_use = + ( ( order == 'r' ) || ( order == 'R' ) || + ( order == 'c' ) || ( order == 'C' ) ) ? + order : 'r'; + + bool is_row_major = ( ( order_use == 'r' ) || ( order_use == 'R' ) ); + bool is_column_major = ( ( order_use == 'c' ) || ( order_use == 'C' ) ); + + // Row major input expected with leading dimensions >= row stride. + if ( ( is_row_major == TRUE ) && + ( ( lda < k ) || ( ldb < n ) || ( ldc < n ) ) ) + { + return; // Error. + } + // Column major input expected with leading dimensions >= column stride. + else if ( ( is_column_major == TRUE ) && + ( ( lda < m ) || ( ldb < k ) || ( ldc < m ) ) ) + { + return; // Error. + } + + // Check if dimensions are valid. + if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || + ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) + { + return; // Error. + } + + const inc_t rs_a = lda; + const inc_t cs_a = 1; + const inc_t rs_b = ldb; + const inc_t cs_b = 1; + const inc_t rs_c = ldc; + const inc_t cs_c = 1; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); + bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); + + // B matrix needs to be packed in a certain format in order to be loaded + // and used in bf16 instrution. As such the mtag_b always needs to be either + // packed or reordered. B matrix as it is (unpacked) cannot be used, and + // the mtag_b is set to packed to enable runtime packing. + if ( ( is_row_major == TRUE ) && ( mtag_b == UNPACKED ) ) + { + mtag_b = PACK; + } + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ( ( is_column_major == TRUE ) && ( mtag_a == UNPACKED ) ) + { + mtag_a = PACK; + } + + // Only unpacked A supported now. + if ( ( is_row_major == TRUE ) && ( mtag_a != UNPACKED ) ) + { + return; // Error. + } + // Inputs swapped in column major, B becomes A from kernel point of view. + else if ( ( is_column_major == TRUE ) && ( mtag_b != UNPACKED ) ) + { + return; // Error. + } + + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + ( void* )c, ( void* )( &order_use ) + ); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + bli_membrk_rntm_set_membrk( &rntm_g ); + +#ifdef BLIS_ENABLE_OPENMP + // Swapping inputs to induce row major computation for column major inputs. 
+ if ( is_column_major == TRUE ) + { + lpgemm_bf16bf16f32of32_openmp_thread_decorator + ( + n, m, k, + b, rs_b, cs_b, mtag_b, + a, rs_a, cs_a, mtag_a, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, FALSE + ); + } + else + { + lpgemm_bf16bf16f32of32_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, FALSE + ); + } +#else + // Swapping inputs to induce row major computation for column major inputs. + if ( is_column_major == TRUE ) + { + lpgemm_bf16bf16f32of32_thread_decorator + ( + n, m, k, + b, rs_b, cs_b, mtag_b, + a, rs_a, cs_a, mtag_a, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, FALSE + ); + } + else + { + lpgemm_bf16bf16f32of32_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, FALSE + ); + } +#endif +} diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c new file mode 100644 index 0000000000..8366f746cb --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c @@ -0,0 +1,174 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_post_ops.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_utils.h" +#include "lpgemm_5loop_interface_apis.h" + +AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if avx ISA is supported, lpgemm fp32 matmul only works with it. + if ( bli_cpuid_is_avx_supported() == FALSE ) + { + printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + + /* Initialize BLIS. 
*/ + bli_init_auto(); + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(s), transa, transb, m, n, k,\ + (void*)&alpha, lda, ldb, (void*)&beta, ldc); + + // Null check for pointers. + if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) + { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, \ + "Invalid pointers provided for input parameters."); + return; // Error. + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); + + /* Perform BLAS parameter checking. */ + // Transpose not supported. + if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || + ( blis_transb != BLIS_NO_TRANSPOSE ) ) + { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, \ + "Input matrix transpose not supported."); + return; // Error. + } + + // Sanitize order input. + char order_use = + ( ( order == 'r' ) || ( order == 'R' ) || + ( order == 'c' ) || ( order == 'C' ) ) ? + order : 'r'; + if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) + { + return; // Only row major supported. + } + + // Row major input expected with leading dimensions equal to row stride. + if ( ( lda != k ) || ( ldb != n ) || ( ldc != n ) ) + { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, \ + "Column major and general stride not supported."); + return; // Error. + } + + // Check if dimensions are valid. + if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || + ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) + { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, \ + "Invalid matrix dimensions."); + return; // Error. + } + + const inc_t rs_a = lda; + const inc_t cs_a = 1; + const inc_t rs_b = ldb; + const inc_t cs_b = 1; + const inc_t rs_c = ldc; + const inc_t cs_c = 1; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); + bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); + + // Only unreordered A supported now. + if ( mtag_a != UNPACKED ) + { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, \ + "A matrix packing/reordering not supported."); + return; // Error. + } + + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + ( void* )c, ( void* )( &order_use ) + ); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + bli_membrk_rntm_set_membrk( &rntm_g ); + +#ifdef BLIS_ENABLE_OPENMP + lpgemm_f32f32f32of32_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, FALSE + ); +#else + // Setting pack A by default for non open mp case. 
+ bli_rntm_set_pack_a( 1, &rntm_g ); + + lpgemm_f32f32f32of32_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, FALSE + ); +#endif + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c b/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c new file mode 100644 index 0000000000..948c1383de --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c @@ -0,0 +1,250 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_utils.h" + +AOCL_GEMM_GET_REORDER_BUF_SIZE(f32f32f32of32) +{ + if ( ( k <= 0 ) || ( n <= 0 ) ) + { + return 0; // Error. + } + + // Check if avx ISA is supported, lpgemm fp32 matmul only works with it. + if ( bli_cpuid_is_avx_supported() == FALSE ) + { + printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + return 0; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Query the global cntx. + cntx_t* cntx = bli_gks_query_cntx(); + + num_t dt = BLIS_FLOAT; + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type ); + + if ( input_mat_type == A_MATRIX ) + { + return 0; // A reorder not supported. + } + + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); + + // Extra space since packing does width in multiples of NR. + const dim_t n_reorder = ( ( n + NR - 1 ) / NR ) * NR; + + siz_t size_req = sizeof( float ) * k * n_reorder; + + return size_req; +} + +// Pack B into row stored column panels. +AOCL_GEMM_REORDER(float,f32f32f32of32) +{ + if ( ( input_buf_addr == NULL ) || ( reorder_buf_addr == NULL ) || + ( k <= 0 ) || ( n <= 0 ) || ( ldb < n ) ) + { + return; // Error. + } + + // Check if avx ISA is supported, lpgemm fp32 matmul only works with it. 
+ if ( bli_cpuid_is_avx_supported() == FALSE ) + { + printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Query the global cntx. + cntx_t* cntx = bli_gks_query_cntx(); + + num_t dt = BLIS_FLOAT; + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type ); + + if ( input_mat_type == A_MATRIX ) + { + return; // A reorder not supported. + } + + const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); + const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); + + // Only supports row major packing now. + inc_t rs_b = ldb; + inc_t cs_b = 1; + + inc_t rs_p = NR; + + float one_local = *PASTEMAC(s,1); + float* restrict kappa_cast = &one_local; + + // Set the schema to "row stored column panels" to indicate packing to + // conventional column-stored row panels. + pack_t schema = BLIS_PACKED_COL_PANELS; + trans_t transc = BLIS_NO_TRANSPOSE; + conj_t conjc = bli_extract_conj( transc ); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + + dim_t n_threads = bli_rntm_num_threads( &rntm_g ); + n_threads = ( n_threads > 0 ) ? n_threads : 1; + +#ifdef BLIS_ENABLE_OPENMP + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( n_threads, &thread_jc ); + bli_thrinfo_set_work_id( omp_get_thread_num(), &thread_jc ); +#else + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( 1, &thread_jc ); + bli_thrinfo_set_work_id( 0, &thread_jc ); +#endif + // Compute the JC loop thread range for the current thread. Per thread + // gets multiple of NR columns. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) + { + dim_t nc0 = bli_min( ( jc_end - jc ), NC ); + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated; + + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, NR, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + + // Compute the total number of iterations we'll need. + dim_t n_iter = ( nc0 + NR - 1 ) / NR; + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + dim_t kc0 = bli_min( ( k - pc ), KC ); + inc_t ps_p = kc0 * NR; + + const float* b_temp = input_buf_addr + ( jc * cs_b ) + ( pc * rs_b ); + + // The offsets are calculated in such a way that it resembles + // the reorder buffer traversal in single threaded reordering. + // The panel boundaries (KCxNC) remain as it is accessed in + // single thread, and as a consequence a thread with jc_start + // inside the panel cannot consider NC range for reorder. It + // has to work with NC' < NC, and the offset is calulated using + // prev NC panels spanning k dim + cur NC panel spaning pc loop + // cur iteration + (NC - NC') spanning current kc0 (<= KC). + // + //Eg: Consider the following reordered buffer diagram: + // t1 t2 + // | | + // | |..NC..| + // | | | + // |.NC. |.NC. 
|NC'|NC" + // pc=0-+-----+-----+---+--+ + // KC| | | | | + // | 1 | 3 | 5 | + // pc=KC-+-----+-----+---st-+ + // KC| | | | | + // | 2 | 4 | 6 | 7| + // pc=k=2KC-+-----+-----+---+--+ + // |jc=0 |jc=NC|jc=2NC| + // + // The numbers 1,2..6,7 denotes the order in which reordered + // KCxNC blocks are stored in memory, ie: block 1 followed by 2 + // followed by 3, etc. Given two threads t1 and t2, and t2 needs + // to acces point st in the reorder buffer to write the data: + // The offset calulation logic will be: + // jc_cur_loop = 2NC, jc_cur_loop_rem = NC', pc = KC, + // n_sub_updated = NC, k = 2KC, kc0_updated = KC + // + // st = ( jc_cur_loop * k ) + // + ( n_sub_updated * pc ) + // + ( NC' * kc0_updated) + float* p_temp = reorder_buf_addr + ( jc_cur_loop * k ) + + ( n_sub_updated * pc ) + ( jc_cur_loop_rem * kc0 ); + + dim_t jr, it; + // Iterate over every logical micropanel in the source matrix. + for ( jr = 0, it = 0; it < n_iter; jr += NR, it += 1 ) + { + dim_t panel_dim_i = bli_min( NR, nc0 - jr ); + + const float* b_use = b_temp + ( jr * cs_b ); + float* p_use = p_temp; + + PASTEMAC(s,packm_cxk) + ( + conjc, + schema, + panel_dim_i, + NR, + kc0, + kc0, + kappa_cast, + ( float* )b_use, cs_b, rs_b, + p_use, rs_p, + cntx + ); + + p_temp += ps_p; + } + } + + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); + } + } +} diff --git a/addon/aocl_gemm/aocl_gemm_interface_apis.h b/addon/aocl_gemm/aocl_gemm_interface_apis.h new file mode 100644 index 0000000000..40101cbe6a --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_interface_apis.h @@ -0,0 +1,107 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef AOCL_GEMM_INTERFACE_H +#define AOCL_GEMM_INTERFACE_H + +#include "aocl_gemm_post_ops.h" +#include "aocl_bf16_type.h" + +// Returns the size of buffer in bytes required for the reordered matrix. 
+#define AOCL_GEMM_GET_REORDER_BUF_SIZE(LP_SFX) \ +BLIS_EXPORT_ADDON siz_t aocl_get_reorder_buf_size_ ## LP_SFX \ + ( \ + const char mat_type, \ + const dim_t k, \ + const dim_t n \ + ) \ + +AOCL_GEMM_GET_REORDER_BUF_SIZE(f32f32f32of32); +AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s32os32); +AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s16os16); +AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16bf16f32of32); + +// Performs reordering of input matrix. Reordering is the process of packing +// the entire matrix upfront, so that the benefits of packed matrix is obtained +// without incurring the packing costs during matmul computation. +#define AOCL_GEMM_REORDER(B_type,LP_SFX) \ +BLIS_EXPORT_ADDON void aocl_reorder_ ## LP_SFX \ + ( \ + const char mat_type, \ + const B_type* input_buf_addr, \ + B_type* reorder_buf_addr, \ + const dim_t k, \ + const dim_t n, \ + const dim_t ldb \ + ) \ + +AOCL_GEMM_REORDER(float,f32f32f32of32); +AOCL_GEMM_REORDER(int8_t,u8s8s32os32); +AOCL_GEMM_REORDER(int8_t,u8s8s16os16); +AOCL_GEMM_REORDER(bfloat16,bf16bf16f32of32); + +// Only supports matrices in row major format. This api can perform gemm with +// both normal as well as reordered B matrix as opposesd to sgemm (only +// supports former). This api can be considered analogous to packed sgemm api. +#define AOCL_GEMM_MATMUL(A_type,B_type,C_type,Sum_type,LP_SFX) \ +BLIS_EXPORT_ADDON void aocl_gemm_ ## LP_SFX \ + ( \ + const char order, \ + const char transa, \ + const char transb, \ + const dim_t m, \ + const dim_t n, \ + const dim_t k, \ + const Sum_type alpha, \ + const A_type* a, \ + const dim_t lda, \ + const char mem_format_a, \ + const B_type* b, \ + const dim_t ldb, \ + const char mem_format_b, \ + const Sum_type beta, \ + C_type* c, \ + const dim_t ldc, \ + aocl_post_op* post_op_unparsed \ + ) \ + +AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32); +AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32); +AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16); +AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32); +AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8); +AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8); +AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16); + +#endif // AOCL_GEMM_INTERFACE_H diff --git a/addon/aocl_gemm/aocl_gemm_post_ops.h b/addon/aocl_gemm/aocl_gemm_post_ops.h new file mode 100644 index 0000000000..86034598ac --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_post_ops.h @@ -0,0 +1,95 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef AOCL_GEMM_POST_OPS_H +#define AOCL_GEMM_POST_OPS_H + +#define AOCL_MAX_POST_OPS 5 + +typedef enum +{ + RELU = 0, + PRELU = 1, +} AOCL_ELT_ALGO_TYPE; + +typedef enum +{ + SUM = 1, + ELTWISE = 2, + BIAS = 3, + SCALE = 4, +} AOCL_POST_OP_TYPE; + +typedef struct +{ + void* alpha; + void* beta; + AOCL_ELT_ALGO_TYPE algo_type; +} aocl_eltwise_algo; + +typedef struct +{ + bool is_power_of_2; + void* scale_factor; + void* buff; + void* zero_point; +} aocl_post_op_sum; // Also use for scale. + +typedef struct +{ + bool is_power_of_2; + void* scale_factor; + aocl_eltwise_algo algo; +} aocl_post_op_eltwise; + +typedef struct +{ + void* bias; +} aocl_post_op_bias; + +typedef struct +{ + aocl_post_op_sum sum; + aocl_post_op_eltwise eltwise; + aocl_post_op_bias bias; + + // eg: seq_length = 2 + dim_t seq_length; + + // eg: seq_vector[0] = BIAS, seq_vector[1] = ELTWISE means bias followed + // by eltwise(relu, if AOCL_ELT_ALGO_TYPE = 1). + AOCL_POST_OP_TYPE* seq_vector; +} aocl_post_op; + +#endif //AOCL_GEMM_POST_OPS_H diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c new file mode 100644 index 0000000000..1c6b0899ad --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c @@ -0,0 +1,167 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_post_ops.h" + +AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if avx ISA is supported, lpgemm u8s8s16os16 matmul only works with it. + if ( bli_cpuid_is_avx_supported() == FALSE ) + { + printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + // Null check for pointers. + if ((a == NULL) || (b == NULL) || (c == NULL)) + { + return; // Error. + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans(transa, &blis_transa); + bli_param_map_netlib_to_blis_trans(transb, &blis_transb); + + /* Perform BLAS parameter checking. */ + // Transpose not supported. + if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || + ( blis_transb != BLIS_NO_TRANSPOSE ) ) + { + return; // Error. + } + + // Sanitize order input. + char order_use = + ( ( order == 'r' ) || ( order == 'R' ) || + ( order == 'c' ) || ( order == 'C' ) ) ? + order : 'r'; + if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) + { + return; // Only row major supported. + } + + // Row major input expected with leading dimensions equal to row stride. + if ((lda != k) || (ldb != n) || (ldc != n)) + { + return; // Error. + } + + // Check if dimensions are valid. + if ((m <= 0) || (n <= 0) || (k <= 0) || (lda <= 0) || (ldb <= 0) || (ldc <= 0)) + { + return; // Error. + } + + const inc_t rs_a = lda; + const inc_t cs_a = 1; + const inc_t rs_b = ldb; + const inc_t cs_b = 1; + const inc_t rs_c = ldc; + const inc_t cs_c = 1; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a); + bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b); + + // B matrix needs to be packed in a certain format in order to be loaded + // and used in VNNI instrution. As such the mtag_b always needs to be either + // packed or reordered. B matrix as it is (unpacked) cannot be used, and + // the mtag_b is set to packed to enable runtime packing. + if (mtag_b == UNPACKED) + { + mtag_b = PACK; + } + + // Only unpacked A supported now. + if (mtag_a != UNPACKED) + { + return; // Error. + } + + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + ( void* )c, ( void* )( &order_use ) + ); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. 
+ rntm_t rntm_g; + bli_rntm_init_from_global(&rntm_g); + bli_membrk_rntm_set_membrk(&rntm_g); + +#ifdef BLIS_ENABLE_OPENMP + lpgemm_u8s8s16o16_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, FALSE + ); +#else + lpgemm_u8s8s16o16_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, FALSE + ); +#endif +} diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c new file mode 100644 index 0000000000..5cadd206d5 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c @@ -0,0 +1,126 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" +#include "lpgemm_reorder_s16.h" + +AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s16os16) +{ + if ((k <= 0) || (n <= 0)) + { + return 0; // Error. + } + + // Check if avx ISA is supported, lpgemm u8s8s16os16 matmul only works with it. + if ( bli_cpuid_is_avx_supported() == FALSE ) + { + printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + return 0; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type(mat_type, &input_mat_type); + + if (input_mat_type == A_MATRIX) + { + return 0; // A reorder not supported. + } + + // Extra space since packing does width in multiples of 16. 
The vpmaddubsw + // instruction can be used as long as atleast one ymm register can be fully + // loaded; and since k_dim needs to be at least 2, having n_dim atleast 16 + // should give 2x16=32 elements, enough for 1 ymm register.The padding is + // not rounded to NR (=16), since that would result in memory wastage. + dim_t n_reorder = make_multiple_of_n(n, 16); + + // Extra space since packing does length in multiples of 2. + dim_t k_reorder = make_multiple_of_n(k, 2); + + siz_t size_req = sizeof(int8_t) * k_reorder * n_reorder; + + return size_req; +} + +AOCL_GEMM_REORDER(int8_t,u8s8s16os16) +{ + if ((input_buf_addr == NULL) || (reorder_buf_addr == NULL) || + (k <= 0) || (n <= 0) || (ldb < n)) + { + return; // Error. + } + + // Check if avx ISA is supported, lpgemm u8s8s16os16 matmul only works with it. + if ( bli_cpuid_is_avx_supported() == FALSE ) + { + printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type(mat_type, &input_mat_type); + + if (input_mat_type == A_MATRIX) + { + return; // A reorder not supported. + } + + // Create dummy b_reorder obj. + lpgemm_obj_t b_reorder; + b_reorder.storage.aligned_buffer = reorder_buf_addr; + + // Create dummy original b obj; + lpgemm_obj_t b; + b.storage.aligned_buffer = (void *)input_buf_addr; + b.rs = ldb; + b.width = n; + b.length = k; + + aocl_reorderb_nr32_u8s8s16o16(&b, &b_reorder); +} diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c new file mode 100644 index 0000000000..fed10c1e01 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c @@ -0,0 +1,167 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
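[Editor's note: worked sizing example.] The sizing rule implemented above (and in the analogous u8s8s32os32 helper later in this patch) is simply "round n up to a multiple of 16, round k up to the dot-product depth, multiply". A worked instance follows; the rounding helper is sketched inline only to show the arithmetic, since the addon's own make_multiple_of_n plays this role.

/* Illustrative rounding helper (assumed equivalent to make_multiple_of_n). */
static inline dim_t round_up_to_multiple( dim_t x, dim_t factor )
{
    return ( ( x + factor - 1 ) / factor ) * factor;
}

/* u8s8s16os16 example: k = 37, n = 50.
     n_reorder = round_up_to_multiple( 50, 16 ) = 64
     k_reorder = round_up_to_multiple( 37, 2 )  = 38
     size_req  = sizeof( int8_t ) * 38 * 64     = 2432 bytes. */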
+ +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_post_ops.h" + +AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if avx ISA is supported, lpgemm u8s8s16os16 matmul only works with it. + if ( bli_cpuid_is_avx_supported() == FALSE ) + { + printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + // Null check for pointers. + if ((a == NULL) || (b == NULL) || (c == NULL)) + { + return; // Error. + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans(transa, &blis_transa); + bli_param_map_netlib_to_blis_trans(transb, &blis_transb); + + /* Perform BLAS parameter checking. */ + // Transpose not supported. + if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || + ( blis_transb != BLIS_NO_TRANSPOSE ) ) + { + return; // Error. + } + + // Sanitize order input. + char order_use = + ( ( order == 'r' ) || ( order == 'R' ) || + ( order == 'c' ) || ( order == 'C' ) ) ? + order : 'r'; + if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) + { + return; // Only row major supported. + } + + // Row major input expected with leading dimensions equal to row stride. + if ((lda != k) || (ldb != n) || (ldc != n)) + { + return; // Error. + } + + // Check if dimensions are valid. + if ((m <= 0) || (n <= 0) || (k <= 0) || (lda <= 0) || (ldb <= 0) || (ldc <= 0)) + { + return; // Error. + } + + const inc_t rs_a = lda; + const inc_t cs_a = 1; + const inc_t rs_b = ldb; + const inc_t cs_b = 1; + const inc_t rs_c = ldc; + const inc_t cs_c = 1; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a); + bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b); + + // B matrix needs to be packed in a certain format in order to be loaded + // and used in VNNI instrution. As such the mtag_b always needs to be either + // packed or reordered. B matrix as it is (unpacked) cannot be used, and + // the mtag_b is set to packed to enable runtime packing. + if (mtag_b == UNPACKED) + { + mtag_b = PACK; + } + + // Only unpacked A supported now. + if (mtag_a != UNPACKED) + { + return; // Error. + } + + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + ( void* )c, ( void* )( &order_use ) + ); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. 
+ rntm_t rntm_g; + bli_rntm_init_from_global(&rntm_g); + bli_membrk_rntm_set_membrk(&rntm_g); + +#ifdef BLIS_ENABLE_OPENMP + lpgemm_u8s8s16o16_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + ( int16_t* )c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, TRUE + ); +#else + lpgemm_u8s8s16o16_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + ( int16_t* )c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, TRUE + ); +#endif +} diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c new file mode 100644 index 0000000000..39fd49bca4 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c @@ -0,0 +1,168 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_post_ops.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" + +AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + { + printf(" AVX512_VNNI ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + // Null check for pointers. + if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) + { + return; // Error. + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); + + /* Perform BLAS parameter checking. */ + // Transpose not supported. 
+ if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || + ( blis_transb != BLIS_NO_TRANSPOSE ) ) + { + return; // Error. + } + + // Sanitize order input. + char order_use = + ( ( order == 'r' ) || ( order == 'R' ) || + ( order == 'c' ) || ( order == 'C' ) ) ? + order : 'r'; + if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) + { + return; // Only row major supported. + } + + // Row major input expected with leading dimensions equal to row stride. + if ( ( lda != k ) || ( ldb != n ) || ( ldc != n ) ) + { + return; // Error. + } + + // Check if dimensions are valid. + if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || + ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) + { + return; // Error. + } + + const inc_t rs_a = lda; + const inc_t cs_a = 1; + const inc_t rs_b = ldb; + const inc_t cs_b = 1; + const inc_t rs_c = ldc; + const inc_t cs_c = 1; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); + bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); + + // B matrix needs to be packed in a certain format in order to be loaded + // and used in VNNI instrution. As such the mtag_b always needs to be either + // packed or reordered. B matrix as it is (unpacked) cannot be used, and + // the mtag_b is set to packed to enable runtime packing. + if ( mtag_b == UNPACKED ) + { + mtag_b = PACK; + } + + // Only unpacked A supported now. + if ( mtag_a != UNPACKED ) + { + return; // Error. + } + + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + ( void* )c, ( void* )( &order_use ) + ); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + bli_membrk_rntm_set_membrk( &rntm_g ); + +#ifdef BLIS_ENABLE_OPENMP + lpgemm_u8s8s32o32_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, FALSE + ); +#else + lpgemm_u8s8s32o32_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, FALSE + ); +#endif +} diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c new file mode 100644 index 0000000000..11f9f6937a --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c @@ -0,0 +1,126 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
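[Editor's note: input-contract sketch.] Every aocl_gemm_* wrapper in this addon applies the same input contract before dispatching: no transposition, strictly positive sizes, and tightly packed row-major strides (lda == k, ldb == n, ldc == n). A small caller-side pre-flight check equivalent to those validations could look like this sketch:

/* Mirrors the validation performed inside the aocl_gemm_* wrappers above. */
static bool lpgemm_layout_is_supported
     (
       dim_t m, dim_t n, dim_t k,
       dim_t lda, dim_t ldb, dim_t ldc
     )
{
    return ( m > 0 ) && ( n > 0 ) && ( k > 0 ) &&
           ( lda == k ) && ( ldb == n ) && ( ldc == n );
}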
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" +#include "lpgemm_reorder.h" + +AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s32os32) +{ + if ( ( k <= 0 ) || ( n <= 0 ) ) + { + return 0; // Error. + } + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + { + printf(" AVX512_VNNI ISA not supported by processor, cannot perform lpgemm.\n"); + return 0; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type ); + + if ( input_mat_type == A_MATRIX ) + { + return 0; // A reorder not supported. + } + + // Extra space since packing does width in multiples of 16. The vnni + // instruction can be used as long as atleast one zmm register can be fully + // loaded; and since k_dim needs to be atleast 4, having n_dim atleast 16 + // should give 4x16=64 elements, enough for 1 zmm register.The padding is + // not rounded to NR (=64), since that would result in memory wastage. + dim_t n_reorder = make_multiple_of_n( n, 16 ); + + // Extra space since packing does length in multiples of 4. + dim_t k_reorder = make_multiple_of_n( k, 4 ); + + siz_t size_req = sizeof( int8_t ) * k_reorder * n_reorder; + + return size_req; +} + +AOCL_GEMM_REORDER(int8_t,u8s8s32os32) +{ + if ( ( input_buf_addr == NULL ) || ( reorder_buf_addr == NULL ) || + ( k <= 0 ) || ( n <= 0 ) || ( ldb < n ) ) + { + return; // Error. + } + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + { + printf(" AVX512_VNNI ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type ); + + if ( input_mat_type == A_MATRIX ) + { + return; // A reorder not supported. + } + + // Create dummy b_reorder obj. 
+ lpgemm_obj_t b_reorder; + b_reorder.storage.aligned_buffer = reorder_buf_addr; + + // Create dummy original b obj; + lpgemm_obj_t b; + b.storage.aligned_buffer = ( void* )input_buf_addr; + b.rs = ldb; + b.width = n; + b.length = k; + + reorderb_nr64_u8s8s32o32( &b, &b_reorder ); +} diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c new file mode 100644 index 0000000000..e4a4ce3f2d --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c @@ -0,0 +1,168 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_post_ops.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" + +AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + { + printf(" AVX512_VNNI ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + // Null check for pointers. + if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) + { + return; // Error. + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); + + /* Perform BLAS parameter checking. */ + // Transpose not supported. + if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || + ( blis_transb != BLIS_NO_TRANSPOSE ) ) + { + return; // Error. + } + + // Sanitize order input. + char order_use = + ( ( order == 'r' ) || ( order == 'R' ) || + ( order == 'c' ) || ( order == 'C' ) ) ? 
+ order : 'r'; + if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) + { + return; // Only row major supported. + } + + // Row major input expected with leading dimensions equal to row stride. + if ( ( lda != k ) || ( ldb != n ) || ( ldc != n ) ) + { + return; // Error. + } + + // Check if dimensions are valid. + if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || + ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) + { + return; // Error. + } + + const inc_t rs_a = lda; + const inc_t cs_a = 1; + const inc_t rs_b = ldb; + const inc_t cs_b = 1; + const inc_t rs_c = ldc; + const inc_t cs_c = 1; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); + bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); + + // B matrix needs to be packed in a certain format in order to be loaded + // and used in VNNI instrution. As such the mtag_b always needs to be either + // packed or reordered. B matrix as it is (unpacked) cannot be used, and + // the mtag_b is set to packed to enable runtime packing. + if ( mtag_b == UNPACKED ) + { + mtag_b = PACK; + } + + // Only unpacked A supported now. + if ( mtag_a != UNPACKED ) + { + return; // Error. + } + + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + ( void* )c, ( void* )( &order_use ) + ); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + bli_membrk_rntm_set_membrk( &rntm_g ); + +#ifdef BLIS_ENABLE_OPENMP + lpgemm_u8s8s32o32_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + ( int32_t* )c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, TRUE + ); +#else + lpgemm_u8s8s32o32_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + ( int32_t* )c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, TRUE + ); +#endif +} diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c new file mode 100644 index 0000000000..5db523f987 --- /dev/null +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c @@ -0,0 +1,366 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
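[Editor's note: downscale illustration, assumptions flagged.] The s8-output wrappers above differ from their s32/s16 counterparts only in that C is handed to the same thread decorator cast to the wider accumulation type and the final flag (c_downscale in the five-loop signature) is TRUE, so the narrowing to int8 happens at store time inside the micro-kernels. The scalar sketch below shows one plausible form of that narrowing, a saturation to the int8 range; whether the in-tree kernels saturate exactly like this, or also fold in the sum post-op scale and zero point, is not visible in this patch and is assumed here.

/* Assumed illustration of the final downscale step for u8s8s32os8 / u8s8s16os8. */
static inline int8_t saturate_to_s8( int32_t acc )
{
    if ( acc > 127 )  return ( int8_t )127;
    if ( acc < -128 ) return ( int8_t )( -128 );
    return ( int8_t )acc;
}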
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_packb_bf16.h" +#include "lpgemm_kernels.h" +#include "lpgemm_utils.h" +#include "lpgemm_thrinfo_utils.h" +#include "lpgemm_config.h" + +// B should always be packed. +LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) +{ + dim_t NC = lpgemm_get_block_size_NC_global_cntx( BF16BF16F32OF32 ); + dim_t KC = lpgemm_get_block_size_KC_global_cntx( BF16BF16F32OF32 ); + dim_t MC = lpgemm_get_block_size_MC_global_cntx( BF16BF16F32OF32 ); + dim_t NR = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ); + dim_t MR = lpgemm_get_block_size_MR_global_cntx( BF16BF16F32OF32 ); + + const int16_t* a_use = NULL; + dim_t cs_a_use = cs_a; + dim_t a_block_stride = 0; + + const int16_t* b_use = NULL; + dim_t rs_b_use = rs_b; + dim_t cs_b_use = cs_b; + + float* c_use_jc = NULL; + float* c_use_ic = NULL; + dim_t rs_c_use = rs_c; + dim_t rs_c_downscale = rs_c; + + // Pack buffer for B. + bfloat16* pack_b_buffer_bf16; + mem_t mem_b = BLIS_MEM_INITIALIZER; + siz_t mem_b_size_req = 0; + dim_t packb_min_NR = 16; + + // Temporary buffer for C accumulation when downscaling is required. + float* temp_scal_c_buffer_bf16; + mem_t mem_scale_c = BLIS_MEM_INITIALIZER; + siz_t mem_scale_c_size_req = 0; + + // kc needs to be a multiple of 2 so that it can be used with dpbf16_ps + // instruction. Padding is added in cases this condition is not + // satisfied, and therefore the k offset used for packed/reordered + // buffer needs to be updated. + dim_t k_updated = k; + k_updated += (k_updated & 0x1); + + // Is required to decide whether to apply post ops or not. + bool is_last_k = FALSE; + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + + lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic ); + + // Compute the JC, IC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + dim_t ic_start, ic_end; + bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end ); + + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) + { + dim_t nc0 = bli_min( ( jc_end - jc ), NC ); + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated; + + if ( mtag_b == REORDERED ) + { + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, packb_min_NR, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + } + + if ( c_downscale == FALSE ) + { + c_use_jc = c + jc; + } + // Temp accumulaton buffer for C allocation. 
+ else if ( c_downscale == TRUE ) + { + mem_scale_c_size_req = sizeof( float ) * nc0 * ( ic_end - ic_start ); + + lpgemm_alloc_mem_panel + ( + mem_scale_c_size_req, BLIS_BUFFER_FOR_C_PANEL, + &mem_scale_c, rntm + ); + + temp_scal_c_buffer_bf16 = bli_mem_buffer( &mem_scale_c ); + + c_use_jc = ( float* )temp_scal_c_buffer_bf16; + + if ( beta != 0 ) + { + dim_t i_temp = 0; + dim_t j_temp = 0; + int32_t temp_conv_buf = 0; + // Upscale out C to temporary C matrix. + for ( dim_t i_dscale = ic_start; i_dscale < ic_end; ++i_dscale ) + { + j_temp = 0; + for ( dim_t j_dscale = jc; j_dscale < ( jc + nc0 ); ++j_dscale ) + { + // Implemented with the idea sizeof(float)=4. + temp_conv_buf = 0; + temp_conv_buf = *( ( int16_t* )( ( bfloat16* )c + + ( rs_c * i_dscale ) + j_dscale ) ); + + // Add 16 bits in the fractional part. + temp_conv_buf = temp_conv_buf << 16; + + // Store the bytes in float format. + *( temp_scal_c_buffer_bf16 + ( nc0 * i_temp ) + j_temp ) + = *( ( float* )( &temp_conv_buf ) ); + + j_temp++; + } + i_temp++; + } + } + + // The temp c buffer stride is modified as opposed to original C matrix. + rs_c_use = nc0; + } + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + float beta0 = ( pc == 0 ) ? beta : 1; + dim_t kc0 = bli_min( ( k - pc ), KC ); + + // kc0 needs to be a multiple of 2 so that it can be + // used with dpbf16_ps instruction. Padding is added in + // cases this condition is not satisfied, and therefore + // the kc0 offsets used for packed/reordered buffers + // needs to be updated. + dim_t kc0_updated = kc0; + kc0_updated += (kc0_updated & 0x1); + + is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE ); + + if ( mtag_b == PACK ) + { + // Pack B chunks are based on jc work id. + dim_t jc_work_id = bli_thread_work_id( &thread_jc ); + + // Using child thrinfo (thread_ic) tid to decide chief thread + // per B matrix chunk (jc work id group) + if ( bli_thread_am_ochief( &thread_ic ) ) + { + // nc0 needs to be a multiple of 16 since this gives maximum + // vectorization. Packing B always results in buffers with width + // which is a multiple of 16. Subsequently the nc0 offsets used + // for packed/reordered buffers needs to be updated. + dim_t nc0_updated = make_multiple_of_n( nc0, packb_min_NR ); + mem_b_size_req = sizeof( bfloat16 ) * nc0_updated * kc0_updated; + + lpgemm_alloc_mem_panel + ( + mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL, + &mem_b, rntm + ); + + thread->comm[jc_work_id].sent_object = + bli_mem_buffer( &mem_b ); + } + + // All threads in work group should wait till chief thread has + // finished allocating the packing buffers. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id( &thread_ic ), + &thread->comm[jc_work_id] + ); + + pack_b_buffer_bf16 = + ( bfloat16* ) thread->comm[jc_work_id].sent_object; + + // Compute the B panel per thread loop range for parallel + // packing using ic_ways number of threads. Since atmost only + // ic_ways threads can be used, the thread_ic attributes are + // used to split the loop range. + dim_t jc_packb_start, jc_packb_end; + bli_thread_range_sub + ( + &thread_ic, nc0, NR, FALSE, + &jc_packb_start, &jc_packb_end + ); + + // Ensure thread ranges are valid, especially cases where no: + // of threads available for parallelization are greater than + // no: of B panel NR chunks. 
+ if ( ( jc_packb_end > jc_packb_start ) && + ( jc_packb_start < ( jc + nc0 ) ) ) + { +#ifdef BLIS_KERNELS_ZEN4 + packb_nr64_bf16bf16f32of32 + ( + pack_b_buffer_bf16 + ( jc_packb_start * kc0_updated ), + ( b + ( rs_b * pc ) + ( cs_b * jc ) + + ( cs_b * jc_packb_start ) ), rs_b, + ( jc_packb_end - jc_packb_start ), kc0, + &rs_b_use, &cs_b_use + ); +#endif + } + else + { + get_packb_nr64_bf16bf16f32of32_strides( &rs_b_use, &cs_b_use ); + } + + // All threads in work group should wait till B matrix packing + // is completed by the participating threads. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id( &thread_ic ), + &thread->comm[jc_work_id] + ); + b_use = pack_b_buffer_bf16; + } + // B part getting processed + else if ( mtag_b == REORDERED ) + { + // In multi-threaded scenarios, an extra offset into a given + // packed B panel is required, since the jc loop split can + // result in per thread start offset inside the panel, instead + // of panel boundaries. + b_use = b + ( jc_cur_loop * k_updated ) + + ( n_sub_updated * pc ) + + ( jc_cur_loop_rem * kc0_updated ); + + get_packb_nr64_bf16bf16f32of32_strides( &rs_b_use, &cs_b_use ); + } + + for ( dim_t ic = ic_start; ic < ic_end; ic += MC ) + { + dim_t mc0 = bli_min( ( ic_end - ic ), MC ); + + // Only per thread C matrix is stored in temp buffer, so both + // per thread jc and ic start should be normalized to zero. + if ( c_downscale == TRUE ) + { + c_use_ic = c_use_jc + ( rs_c_use * ( ic - ic_start ) ); + } + else + { + c_use_ic = c_use_jc + ( rs_c_use * ic ); + } + + if ( mtag_a == UNPACKED ) + { + a_use = a + ( rs_a * ic ) + ( cs_a * pc ); + + // bf16 kernel reads 2 elements, totalling 4 bytes in a + // single broadcast for use in bf16 instruction. + // Non bf16 based kernel requires update to this code. + cs_a_use = 2; + a_block_stride = rs_a; + } + + for ( dim_t jr = 0; jr < nc0; jr += NR ) + { + dim_t nr0 = bli_min( ( nc0 - jr ), NR ); + +#ifdef BLIS_KERNELS_ZEN4 + // Reorder/Packed B, Reorder/Packed/Unpacked A call. + lpgemm_rowvar_bf16bf16f32of32_6x64 + ( + mc0, nr0, kc0, + a_use, rs_a, cs_a_use, a_block_stride, + ( b_use + ( jr * kc0_updated ) ), rs_b_use, cs_b_use, + ( c_use_ic + jr ), rs_c_use, 1, + alpha, beta0, + is_last_k, ic, ( jc + jr ), post_op_list, rs_c_downscale + ); +#else + // Silence compiler warnings. + ( void )b_use; + ( void )a_block_stride; + ( void )rs_c_downscale; + ( void )is_last_k; + ( void )c_use_ic; + ( void )a_use; + ( void )beta0; + ( void )nr0; + ( void )mc0; + ( void )cs_a_use; +#endif + } + } + } + if ( mtag_b == REORDERED ) + { + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); + } + } + + // Release pack buffers. + if ( mtag_b == PACK ) + { + // All threads in work group should wait till B matrix usage is + // completed by the participating threads. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id( &thread_jc ), + &thread->comm[bli_thread_work_id( &thread_jc)] + ); + + if ( bli_thread_am_ochief( &thread_ic ) ) + { + if ( bli_mem_is_alloc( &mem_b ) ) + { + bli_membrk_release( rntm, &mem_b ); + } + } + } + if ( c_downscale == TRUE ) + { + if ( bli_mem_is_alloc( &mem_scale_c ) ) + { + bli_membrk_release( rntm, &mem_scale_c ); + } + } +} diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c new file mode 100644 index 0000000000..5bb217facd --- /dev/null +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c @@ -0,0 +1,180 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
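[Editor's note: bf16 upscale sketch.] The beta-handling branch in the bf16 five-loop above widens the existing bfloat16 C values to float by moving their 16 stored bits into the upper half of a 32-bit word, since bfloat16 is the top half of an IEEE-754 single. A standalone version of that conversion, written here with an unsigned shift, is:

#include <stdint.h>
#include <string.h>

/* Widen one bfloat16 (raw 16 bits) to float; equivalent to the
   "add 16 bits in the fractional part" upscale loop above. */
static inline float bf16_bits_to_float( uint16_t bits )
{
    uint32_t widened = ( ( uint32_t )bits ) << 16;
    float    f;
    memcpy( &f, &widened, sizeof( f ) );
    return f;
}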
+ + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_utils.h" +#include "lpgemm_reorder_bf16.h" +#include "lpgemm_packb_bf16.h" +#include "lpgemm_config.h" +#include "aocl_bf16_type.h" + +void reorderb_nr64_bf16bf16f32of32 + ( + lpgemm_obj_t *b, + lpgemm_obj_t *b_reorder + ) +{ + dim_t NC = lpgemm_get_block_size_NC_global_cntx( BF16BF16F32OF32 ); + dim_t NR = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ); + dim_t KC = lpgemm_get_block_size_KC_global_cntx( BF16BF16F32OF32 ); + + // Extracting the matrix properties from the lpgemm object + dim_t rs_b = b->rs; + dim_t n = b->width; + dim_t k = b->length; + + dim_t rs_b_reorder; + dim_t cs_b_reorder; + + // k needs to be a multiple of 2 so that it can be used with dpbf + // instruction. Padding is added in cases this condition is not + // satisfied, and therefore the k offset used for packed/reordered + // buffer needs to be updated. + dim_t k_updated = k; + k_updated += (k_updated & 0x1); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + + dim_t n_threads = bli_rntm_num_threads( &rntm_g ); + n_threads = ( n_threads > 0 ) ? n_threads : 1; + +#ifdef BLIS_ENABLE_OPENMP + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( n_threads, &thread_jc ); + bli_thrinfo_set_work_id( omp_get_thread_num(), &thread_jc ); +#else + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( 1, &thread_jc ); + bli_thrinfo_set_work_id( 0, &thread_jc ); +#endif + // Compute the JC loop thread range for the current thread. 
+ dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) + { + dim_t nc0 = bli_min( ( jc_end - jc ), NC ); + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated; + + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, 16, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + dim_t kc0 = bli_min( ( k - pc ), KC ); + + // k needs to be a multiple of 2 so that it can be used with dpbf + // instruction. Padding is added in cases this condition is not + // satisfied, and therefore the k offset used for packed/reordered + // buffer needs to be updated. + dim_t kc0_updated = kc0; + kc0_updated += (kc0_updated & 0x1); + + // The offsets are calculated in such a way that it resembles + // the reorder buffer traversal in single threaded reordering. + // The panel boundaries (KCxNC) remain as it is accessed in + // single thread, and as a consequence a thread with jc_start + // inside the panel cannot consider NC range for reorder. It + // has to work with NC' < NC, and the offset is calulated using + // prev NC panels spanning k dim + cur NC panel spaning pc loop + // cur iteration + (NC - NC') spanning current kc0 (<= KC). + // + //Eg: Consider the following reordered buffer diagram: + // t1 t2 + // | | + // | |..NC..| + // | | | + // |.NC. |.NC. |NC'|NC" + // pc=0-+-----+-----+---+--+ + // KC| | | | | + // | 1 | 3 | 5 | + // pc=KC-+-----+-----+---st-+ + // KC| | | | | + // | 2 | 4 | 6 | 7| + // pc=k=2KC-+-----+-----+---+--+ + // |jc=0 |jc=NC|jc=2NC| + // + // The numbers 1,2..6,7 denotes the order in which reordered + // KCxNC blocks are stored in memory, ie: block 1 followed by 2 + // followed by 3, etc. Given two threads t1 and t2, and t2 needs + // to acces point st in the reorder buffer to write the data: + // The offset calulation logic will be: + // jc_cur_loop = 2NC, jc_cur_loop_rem = NC', pc = KC, + // n_sub_updated = NC, k = 2KC, kc0_updated = KC + // + // st = ( jc_cur_loop * k ) + // + ( n_sub_updated * pc ) + // + ( NC' * kc0_updated) +#ifdef BLIS_KERNELS_ZEN4 + // B should always be packed. + packb_nr64_bf16bf16f32of32 + ( + ( ( ( bfloat16* )b_reorder->storage.aligned_buffer ) + + ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) + + ( jc_cur_loop_rem * kc0_updated ) ), + ( ( ( bfloat16* )b->storage.aligned_buffer ) + + ( rs_b * pc ) + jc ), + rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder + ); +#else + // Silence compiler warnings. + rs_b_reorder = 0; + cs_b_reorder = 0; + ( void )rs_b; +#endif + } + + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); + } + } + + b_reorder->rs = rs_b_reorder; + b_reorder->cs = cs_b_reorder; + b_reorder->mtag = REORDERED; +} diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.h b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.h new file mode 100644 index 0000000000..c1b83c1b75 --- /dev/null +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.h @@ -0,0 +1,46 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_REORDER_BF16_H +#define LPGEMM_REORDER_BF16_H + +#include "lpgemm_types.h" + +void reorderb_nr64_bf16bf16f32of32 + ( + lpgemm_obj_t *b, + lpgemm_obj_t *b_reorder + ); + +#endif // LPGEMM_REORDER_H diff --git a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c new file mode 100644 index 0000000000..6242ceebe8 --- /dev/null +++ b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c @@ -0,0 +1,300 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
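[Editor's note: worked offset example.] The offset expression documented in the reorder diagram above, st = (jc_cur_loop * k) + (n_sub_updated * pc) + (jc_cur_loop_rem * kc0_updated), is easiest to see with concrete numbers. The block sizes below are chosen purely for illustration; the real bf16 defaults set at init time are NC = 1024, KC = 2048.

/* Worked instance of the reordered-B offset formula from the diagram above. */
static dim_t example_reorder_offset( void )
{
    const dim_t NC = 64, KC = 128;

    const dim_t k_updated       = 2 * KC;  /* two full KC panels, already even.   */
    const dim_t jc_cur_loop     = 2 * NC;  /* two full NC panels lie to the left. */
    const dim_t jc_cur_loop_rem = 16;      /* NC' in the diagram.                 */
    const dim_t pc              = KC;      /* second KC iteration.                */
    const dim_t n_sub_updated   = NC;
    const dim_t kc0_updated     = KC;

    /* 32768 + 8192 + 2048 = 43008 elements into the reordered buffer. */
    return ( jc_cur_loop * k_updated ) +
           ( n_sub_updated * pc ) +
           ( jc_cur_loop_rem * kc0_updated );
}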
+ +*/ + +#include "blis.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_utils.h" +#include "lpgemm_thrinfo_utils.h" + +void lpgemm_pack_a_f32f32f32of32 + ( + const float* input_buf_addr_a, + float* reorder_buf_addr_a, + const dim_t m, + const dim_t k, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_p, + const dim_t MR, + cntx_t* cntx + ); + +LPGEMM_5LOOP(float,float,float,f32f32f32of32) +{ + // Query the global cntx. + cntx_t* cntx = bli_gks_query_cntx(); + + num_t dt = BLIS_FLOAT; + + // Query the context for various blocksizes. + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); + const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); + const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); + const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); + const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); + + // Strides are updated based on matrix packing/reordering. + const float* a_use = NULL; + dim_t rs_a_use = rs_a; + dim_t cs_a_use = cs_a; + + const float* b_use = NULL; + dim_t rs_b_use = rs_b; + dim_t cs_b_use = cs_b; + + float* c_use_jc = NULL; + float* c_use_ic = NULL; + + // Only supporting row major with unit column strided C for now. + const dim_t cs_c_use = 1; + + /* Compute partitioning step values for each matrix of each loop. */ + inc_t ps_a_use; + inc_t ps_b_use; + auxinfo_t aux; + + // Check if packing of A is required. + bool should_pack_A = bli_rntm_pack_a( rntm ); + + // Pack buffer for A. + float* pack_a_buffer_f32f32f32of32; + mem_t mem_a = BLIS_MEM_INITIALIZER; + siz_t mem_a_size_req = 0; + + float one_local = *PASTEMAC(s,1); + + trans_t transc = BLIS_NO_TRANSPOSE; + conj_t conjc = bli_extract_conj( transc ); + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + + lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic ); + + // Compute the JC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) + { + dim_t nc0 = bli_min( ( jc_end - jc ), NC ); + c_use_jc = c + jc; + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated; + + if ( mtag_b == REORDERED ) + { + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, NR, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + } + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + float beta0 = ( pc == 0 ) ? beta : one_local; + dim_t kc0 = bli_min( ( k - pc ), KC ); + + if ( mtag_b == REORDERED ) + { + // In multi-threaded scenarios, an extra offset into a given + // packed B panel is required, since the jc loop split can + // result in per thread start offset inside the panel, instead + // of panel boundaries. + b_use = b + ( jc_cur_loop * k ) + + ( n_sub_updated * pc ) + ( jc_cur_loop_rem * kc0 ); + + rs_b_use = NR; + cs_b_use = 1; + ps_b_use = kc0; + } + else + { + b_use = b + ( pc * rs_b ) + ( jc * cs_b ); + ps_b_use = 1; + } + + dim_t ic_start, ic_end; + bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end ); + + for ( dim_t ic = ic_start; ic < ic_end; ic += MC ) + { + dim_t mc0 = bli_min( ( ic_end - ic ), MC ); + c_use_ic = c_use_jc + ( rs_c * ic ); + + if ( mtag_a == REORDERED ) + { + // Extra space since packing does width in multiples of MR. 
+ const dim_t m_updated = ( ( m + MR - 1 ) / MR ) * MR; + a_use = a + ( pc * m_updated ) + ( kc0 * ic ); + + rs_a_use = 1; + cs_a_use = MR; + ps_a_use = MR * kc0; + } + else if ( should_pack_A == TRUE ) + { + // Extra space since packing does width in multiples of MR. + const dim_t mc0_updated = ( ( mc0 + MR - 1 ) / MR ) * MR; + mem_a_size_req = sizeof( float ) * mc0_updated * kc0; + + lpgemm_alloc_mem_panel + ( + mem_a_size_req, BLIS_BUFFER_FOR_A_BLOCK, + &mem_a, rntm + ); + pack_a_buffer_f32f32f32of32 = ( float* )bli_mem_buffer( &mem_a ); + + rs_a_use = 1; + cs_a_use = MR; + ps_a_use = MR * kc0; + + lpgemm_pack_a_f32f32f32of32 + ( + ( a + ( rs_a * ic ) + pc ), + pack_a_buffer_f32f32f32of32, + mc0, kc0, + rs_a, cs_a, ps_a_use, MR, + cntx + ); + + a_use = pack_a_buffer_f32f32f32of32; + } + else + { + a_use = a + ( rs_a * ic ) + pc; + ps_a_use = MR * rs_a; + } + + // Embed the panel stride of A within the auxinfo_t object. The + // millikernel will query and use this to iterate through + // micropanels of A (if needed). + bli_auxinfo_set_ps_a( ps_a_use, &aux ); + + for ( dim_t jr = 0; jr < nc0; jr += NR ) + { + dim_t nr0 = bli_min( ( nc0 - jr ), NR ); + + // Reordered/unpacked B, reordered/unpacked A. + bli_sgemmsup_rv_zen_asm_6x16m + ( + conjc, + conjc, + mc0, nr0, kc0, + &alpha, + ( float* )a_use, rs_a_use, cs_a_use, + ( float* )( b_use + ( jr * ps_b_use ) ), rs_b_use, cs_b_use, + &beta0, + ( c_use_ic + jr ), rs_c, cs_c_use, + &aux, cntx + ); + } + } + } + if ( mtag_b == REORDERED ) + { + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); + } + } + + // Release pack buffers. + if ( should_pack_A == TRUE ) + { + if ( bli_mem_is_alloc( &mem_a ) ) + { + bli_membrk_release( rntm, &mem_a ); + } + } +} + +void lpgemm_pack_a_f32f32f32of32 + ( + const float* input_buf_addr_a, + float* reorder_buf_addr_a, + const dim_t m, + const dim_t k, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_p, + const dim_t MR, + cntx_t* cntx + ) +{ + float one_local = *PASTEMAC(s,1); + float* restrict kappa_cast = &one_local; + + // Set the schema to "column stored row panels" to indicate packing to conventional + // column-stored row panels. + pack_t schema = BLIS_PACKED_ROW_PANELS; + trans_t transc = BLIS_NO_TRANSPOSE; + conj_t conjc = bli_extract_conj( transc ); + + // Compute the total number of iterations we'll need. + dim_t m_iter = ( m + MR - 1 ) / MR; + + inc_t cs_p = MR; + + float* p_temp = reorder_buf_addr_a; + dim_t ir, it; + // Iterate over every logical micropanel in the source matrix. + for ( ir = 0, it = 0; it < m_iter; ir += MR, it += 1 ) + { + dim_t panel_dim_i = bli_min( MR, m - ir ); + + const float* a_use = input_buf_addr_a + ( ir * rs_a ); + float* p_use = p_temp; + + PASTEMAC(s,packm_cxk) + ( + conjc, + schema, + panel_dim_i, + MR, + k, + k, + kappa_cast, + ( float* )a_use, rs_a, cs_a, + p_use, cs_p, + cntx + ); + + p_temp += ps_p; + } +} diff --git a/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h b/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h new file mode 100644 index 0000000000..45328669de --- /dev/null +++ b/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h @@ -0,0 +1,71 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
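[Editor's note: A-packing arithmetic.] In the f32 path above, A is packed into MR-wide row panels, so both the reordered-A addressing and the on-the-fly pack round the m dimension up to a multiple of MR and hand the kernel a panel stride of MR * kc0. With assumed values (MR = 6, matching the 6x16 sup kernel used above; mc0 = 20; kc0 = 128) the arithmetic works out as follows:

/* Illustrative numbers only; MR/MC/KC are really queried from the BLIS context. */
static void example_pack_a_sizing( void )
{
    const dim_t MR  = 6;
    const dim_t mc0 = 20;
    const dim_t kc0 = 128;

    const dim_t mc0_updated    = ( ( mc0 + MR - 1 ) / MR ) * MR;      /* 24                   */
    const siz_t mem_a_size_req = sizeof( float ) * mc0_updated * kc0; /* 24 * 128 * 4 = 12288 */
    const inc_t ps_a_use       = MR * kc0;                            /* 768 floats per panel */

    ( void )mem_a_size_req;  /* values shown for the arithmetic only. */
    ( void )ps_a_use;
}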
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_5LOOP_INTF_H +#define LPGEMM_5LOOP_INTF_H + +#include "lpgemm_types.h" +#include "lpgemm_post_ops.h" +#include "aocl_bf16_type.h" + +#define LPGEMM_5LOOP(A_type,B_type,C_type,LP_SFX) \ +void lpgemm_rowvar_ ## LP_SFX \ + ( \ + const dim_t m, \ + const dim_t n, \ + const dim_t k, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const AOCL_MEMORY_TAG mtag_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + const AOCL_MEMORY_TAG mtag_b, \ + C_type* c, \ + const dim_t rs_c, \ + const dim_t cs_c, \ + C_type alpha, \ + C_type beta, \ + rntm_t* rntm, \ + lpgemm_thrinfo_t* thread, \ + lpgemm_post_op* post_op_list, \ + bool c_downscale \ + ) \ + +LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32); +LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16); +LPGEMM_5LOOP(float,float,float,f32f32f32of32); +LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32); +#endif // LPGEMM_5LOOP_INTF_H diff --git a/addon/aocl_gemm/frame/lpgemm_config.c b/addon/aocl_gemm/frame/lpgemm_config.c new file mode 100644 index 0000000000..901ec087d2 --- /dev/null +++ b/addon/aocl_gemm/frame/lpgemm_config.c @@ -0,0 +1,90 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
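[Editor's note: macro expansion for reference.] The LPGEMM_5LOOP macro above is purely a prototype generator; written out, the u8s8s32o32 instantiation it declares is the plain function below, with the types substituted per the macro arguments (AOCL_MEMORY_TAG, lpgemm_thrinfo_t and lpgemm_post_op come from the headers included above).

void lpgemm_rowvar_u8s8s32o32
     (
       const dim_t m,
       const dim_t n,
       const dim_t k,
       const uint8_t* a,
       const dim_t rs_a,
       const dim_t cs_a,
       const AOCL_MEMORY_TAG mtag_a,
       const int8_t* b,
       const dim_t rs_b,
       const dim_t cs_b,
       const AOCL_MEMORY_TAG mtag_b,
       int32_t* c,
       const dim_t rs_c,
       const dim_t cs_c,
       int32_t alpha,
       int32_t beta,
       rntm_t* rntm,
       lpgemm_thrinfo_t* thread,
       lpgemm_post_op* post_op_list,
       bool c_downscale
     );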
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_config.h" + +lpgemm_cntx_t global_cntx_t_list[4]; //Only one op type supported now. + +BLIS_INLINE void lpgemm_set_block_sizes_global_cntx + ( + AOCL_OPERATION_TYPE op_type, + dim_t MC, + dim_t NC, + dim_t KC, + dim_t NR, + dim_t MR + ) +{ + global_cntx_t_list[op_type].blksz.MC = MC; + global_cntx_t_list[op_type].blksz.NC = NC; + global_cntx_t_list[op_type].blksz.KC = KC; + global_cntx_t_list[op_type].blksz.NR = NR; + global_cntx_t_list[op_type].blksz.MR = MR; +} + +// Sets default block sizes for lpgemm. Currently only u8s8s32 supported. +// Thread safety is not considered now since the block sizes are not expected +// to be configurable from application. +void aocl_lpgemm_init_global_cntx() +{ + lpgemm_set_block_sizes_global_cntx( U8S8S32OS32, 144, 1024, 2048, 64, 6 ); + lpgemm_set_block_sizes_global_cntx( U8S8S16OS16, 144, 1024, 1024, 32, 6 ); + lpgemm_set_block_sizes_global_cntx( BF16BF16F32OF32, 144, 1024, 2048, 64, 6 ); +} + +dim_t lpgemm_get_block_size_MC_global_cntx( AOCL_OPERATION_TYPE op_type ) +{ + return global_cntx_t_list[op_type].blksz.MC; +} + +dim_t lpgemm_get_block_size_NC_global_cntx( AOCL_OPERATION_TYPE op_type ) +{ + return global_cntx_t_list[op_type].blksz.NC; +} + +dim_t lpgemm_get_block_size_KC_global_cntx( AOCL_OPERATION_TYPE op_type ) +{ + return global_cntx_t_list[op_type].blksz.KC; +} + +dim_t lpgemm_get_block_size_NR_global_cntx( AOCL_OPERATION_TYPE op_type ) +{ + return global_cntx_t_list[op_type].blksz.NR; +} + +dim_t lpgemm_get_block_size_MR_global_cntx( AOCL_OPERATION_TYPE op_type ) +{ + return global_cntx_t_list[op_type].blksz.MR; +} diff --git a/addon/aocl_gemm/frame/lpgemm_config.h b/addon/aocl_gemm/frame/lpgemm_config.h new file mode 100644 index 0000000000..7e7f3bb2ad --- /dev/null +++ b/addon/aocl_gemm/frame/lpgemm_config.h @@ -0,0 +1,55 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_CONFIG_H +#define LPGEMM_CONFIG_H + +#include "lpgemm_types.h" + +// equals to number of ops in enum AOCL_OPERATION_TYPE. +extern lpgemm_cntx_t lpgemm_global_cntx_t_list[4]; + +void aocl_lpgemm_init_global_cntx(); + +dim_t lpgemm_get_block_size_MC_global_cntx( AOCL_OPERATION_TYPE op_type ); + +dim_t lpgemm_get_block_size_NC_global_cntx( AOCL_OPERATION_TYPE op_type ); + +dim_t lpgemm_get_block_size_KC_global_cntx( AOCL_OPERATION_TYPE op_type ); + +dim_t lpgemm_get_block_size_NR_global_cntx( AOCL_OPERATION_TYPE op_type ); + +dim_t lpgemm_get_block_size_MR_global_cntx( AOCL_OPERATION_TYPE op_type ); + +#endif //LPGEMM_CONFIG_H diff --git a/addon/aocl_gemm/frame/lpgemm_post_ops.c b/addon/aocl_gemm/frame/lpgemm_post_ops.c new file mode 100644 index 0000000000..63fb25765f --- /dev/null +++ b/addon/aocl_gemm/frame/lpgemm_post_ops.c @@ -0,0 +1,155 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" +#include "lpgemm_post_ops.h" + +BLIS_INLINE void lpgemm_set_node_params + ( + lpgemm_post_op* post_op_node, + LPGEMM_POST_OP_CODE op_code, + void* op1, + void* op2, + void* op3, + void* scale_factor, + bool is_power_of_2 + ) +{ + post_op_node->op_code = op_code; + post_op_node->op_args1 = op1; + post_op_node->op_args2 = op2; + post_op_node->op_args3 = op3; + post_op_node->scale_factor = scale_factor; + post_op_node->is_power_of_2 = is_power_of_2; + post_op_node->next = NULL; +} + +void lpgemm_translate_to_post_ops_list + ( + aocl_post_op* post_op_unparsed, + lpgemm_post_op* post_op_list, + void* scale_buffer, + void* meta_arg + ) +{ + if ( post_op_unparsed == NULL ) + { + lpgemm_set_node_params + ( + post_op_list, POST_OPS_DISABLE, + NULL, NULL, NULL, NULL, FALSE + ); + return; + } + + if ( ( post_op_unparsed->seq_length > AOCL_MAX_POST_OPS ) ) + { + lpgemm_set_node_params + ( + post_op_list, POST_OPS_DISABLE, + NULL, NULL, NULL, NULL, FALSE + ); + return; //Error, seq length exceeds max post ops permitted. + } + + for ( dim_t i = 0; i < post_op_unparsed->seq_length; ++i ) + { + // Dispatcher code + switch ( *( post_op_unparsed->seq_vector + i ) ) + { + case SUM: + lpgemm_set_node_params + ( + ( post_op_list + i ), POST_OPS_SUM, + post_op_unparsed->sum.buff, + post_op_unparsed->sum.zero_point, + NULL, + post_op_unparsed->sum.scale_factor, + post_op_unparsed->sum.is_power_of_2 + ); + break; + case ELTWISE: + { + LPGEMM_POST_OP_CODE tmp_code = POST_OPS_DISABLE; + // Eltwise algo dispatcher. + switch ( post_op_unparsed->eltwise.algo.algo_type ) + { + case RELU: + tmp_code = POST_OPS_RELU; + break; + case PRELU: + tmp_code = POST_OPS_RELU_SCALE; + break; + default: + break; + } + lpgemm_set_node_params + ( + ( post_op_list + i ), tmp_code, + NULL, + post_op_unparsed->eltwise.algo.alpha, + post_op_unparsed->eltwise.algo.beta, + post_op_unparsed->eltwise.scale_factor, + post_op_unparsed->eltwise.is_power_of_2 + ); + } + break; + case BIAS: + lpgemm_set_node_params + ( + ( post_op_list + i ), POST_OPS_BIAS, + post_op_unparsed->bias.bias, + meta_arg, NULL, NULL, FALSE + ); + break; + case SCALE: + lpgemm_set_node_params + ( + ( post_op_list + i ), POST_OPS_DOWNSCALE, + post_op_unparsed->sum.zero_point, + meta_arg, scale_buffer, + post_op_unparsed->sum.scale_factor, FALSE + ); + break; + default: + break; + } + + // Simulating linked link using an array. + if ( i < ( post_op_unparsed->seq_length - 1 ) ) + { + ( post_op_list + i )->next = ( post_op_list + i + 1); + } + } +} diff --git a/addon/aocl_gemm/frame/lpgemm_post_ops.h b/addon/aocl_gemm/frame/lpgemm_post_ops.h new file mode 100644 index 0000000000..3932daf602 --- /dev/null +++ b/addon/aocl_gemm/frame/lpgemm_post_ops.h @@ -0,0 +1,89 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_POST_OPS_H +#define LPGEMM_POST_OPS_H + +typedef enum +{ + POST_OPS_DISABLE = 0, + POST_OPS_BIAS = 1, + POST_OPS_RELU = 2, + POST_OPS_RELU_SCALE = 3, + POST_OPS_DOWNSCALE = 4, + POST_OPS_SUM = 5, +} LPGEMM_POST_OP_CODE; + +// Used as an internal structure. +typedef struct lpgemm_post_op_t +{ + LPGEMM_POST_OP_CODE op_code; + void* op_args1; + void* op_args2; // alpha, zero_point, storage order + void* op_args3; // beta, downscale buffer/original C matrix + void* scale_factor; + bool is_power_of_2; + struct lpgemm_post_op_t* next; +} lpgemm_post_op; + +void lpgemm_translate_to_post_ops_list + ( + aocl_post_op* post_op_unparsed, + lpgemm_post_op* post_op_list, + void* scale_buffer, + void* meta_arg + ); + +#define POST_OP_LABEL_LASTK_SAFE_JUMP \ + if ( ( is_last_k == TRUE ) && ( post_ops_list_temp != NULL ) ) \ + { \ + goto *post_ops_labels[post_ops_list_temp->op_code]; \ + } \ + else \ + { \ + goto *post_ops_labels[0]; \ + } + +#define POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR \ + post_ops_list_temp = post_ops_list_temp->next; \ + if ( post_ops_list_temp != NULL ) \ + { \ + goto *post_ops_labels[post_ops_list_temp->op_code]; \ + } \ + else \ + { \ + goto *post_ops_labels[0]; \ + } + +#endif //LPGEMM_POST_OPS_H diff --git a/addon/aocl_gemm/frame/lpgemm_types.h b/addon/aocl_gemm/frame/lpgemm_types.h new file mode 100644 index 0000000000..aebd485d0d --- /dev/null +++ b/addon/aocl_gemm/frame/lpgemm_types.h @@ -0,0 +1,117 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_TYPES_H +#define LPGEMM_TYPES_H + +typedef enum +{ + INT8 = 0, + INT16 = 1, + INT32 = 2 +} AOCL_ARRAY_TYPE; + +// Enum name template:A_mat_type ## B_mat_type ## Accumulate_type ## C_mat_type. +typedef enum +{ + U8S8S16OS16 = 0, // uint8_t - A, int8_t - B, int16_t - C + U8S8S32OS32 = 1, // uint8_t - A, int8_t - B, int32_t - C + F16F16F16OF16 = 2, // float16 - A, float16 - B, float16 - C + BF16BF16F32OF32 = 3 // bf16 - A, bf16 - B, float - C +} AOCL_OPERATION_TYPE; + +typedef enum +{ + UNPACKED = 0, + PACK = 1, + REORDERED = 2, +} AOCL_MEMORY_TAG; + +typedef enum +{ + ROW_MAJOR = 0, + COLUMN_MAJOR = 1, +} AOCL_STOR_TAG; + +typedef enum +{ + A_MATRIX = 0, + B_MATRIX = 1, +} AOCL_MATRIX_TYPE; + +typedef struct +{ + void* aligned_buffer; + void* origin_buffer; +} lpgemm_mem_t; + +typedef struct +{ + dim_t length; + dim_t width; + + dim_t elem_size; + + dim_t rs; + dim_t cs; + + AOCL_MEMORY_TAG mtag; + + lpgemm_mem_t storage; +} lpgemm_obj_t; + +typedef struct +{ + dim_t MC; + dim_t NC; + dim_t KC; + dim_t NR; + dim_t MR; +} lpgemm_block_size_t; + +typedef struct +{ + lpgemm_block_size_t blksz; +} lpgemm_cntx_t; + +typedef struct +{ + dim_t n_threads; + dim_t tid; + dim_t ic_ways; + dim_t jc_ways; + thrcomm_t* comm; +} lpgemm_thrinfo_t; + +#endif //LPGEMM_TYPES_H diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c new file mode 100644 index 0000000000..0c1df5e7c3 --- /dev/null +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c @@ -0,0 +1,634 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_config.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_types.h" +#include "lpgemm_5loop_interface_apis.h" + +#ifdef BLIS_ENABLE_OPENMP + +#define BLIS_LPGEMM_NUM_STATIC_COMMS 96 + +BLIS_INLINE dim_t next_factor + ( + const dim_t nt, + const dim_t part_nt + ) +{ + if ( part_nt == nt ) + { + return part_nt; + } + + dim_t nt_temp = part_nt + 1; + while ( ( nt_temp <= nt ) && ( ( nt % nt_temp ) != 0 ) ) + { + nt_temp++; + } + return nt_temp; +} + +BLIS_INLINE dim_t prev_factor + ( + const dim_t nt, + const dim_t part_nt + ) +{ + if ( part_nt == 1 ) + { + return part_nt; + } + + dim_t nt_temp = part_nt - 1; + while ( ( nt_temp >= 1 ) && ( ( nt % nt_temp ) != 0 ) ) + { + nt_temp--; + } + return nt_temp; +} + +BLIS_INLINE void lpgemm_pnl_wrk_heur_adjust_ic_jc_ways + ( + dim_t MR, + dim_t NR, + dim_t m, + dim_t n, + dim_t* n_threads, + dim_t* ic_ways, + dim_t* jc_ways + ) +{ + // This function currently only increments ic and subsequently decrements + // jc. Cannot proceed if all threads are allocated to ic. + // The factorization adjustment here is based on improving the B NR panel + // distribution among the jc threads. + dim_t mu = ( m + MR - 1 ) / MR; + dim_t nu = ( n + NR - 1 ) / NR; + + // The next 3 ic factors will be considered to see if it results in better + // NR panel distribution and subsequently reduce the per thread panel work. 
+ dim_t nu_mod_jc_ways = nu % ( *jc_ways ); + if ( ( nu_mod_jc_ways != 0 ) && ( ( *ic_ways ) < ( *n_threads ) ) ) + { + dim_t mu_ic_cur = ( mu + ( *ic_ways ) - 1 ) / ( *ic_ways ); + dim_t nu_jc_cur = ( nu + ( *jc_ways ) - 1 ) / ( *jc_ways ); + dim_t panel_work_cur = mu_ic_cur + nu_jc_cur; + + const dim_t next_ic = next_factor( ( *n_threads ), ( *ic_ways ) ); + const dim_t prev_jc = prev_factor( ( *n_threads ), ( *jc_ways ) ); + dim_t mu_ic_next = ( mu + next_ic - 1 ) / next_ic; + dim_t nu_jc_prev = ( nu + prev_jc - 1 ) / prev_jc; + dim_t panel_work_next = mu_ic_next + nu_jc_prev; + + if ( panel_work_next < panel_work_cur ) + { + panel_work_cur = panel_work_next; + ( *ic_ways ) = next_ic; + ( *jc_ways ) = prev_jc; + } + + nu_mod_jc_ways = nu % ( *jc_ways ); + if ( ( nu_mod_jc_ways != 0 ) && ( next_ic < ( *n_threads ) ) ) + { + const dim_t next_next_ic = next_factor( ( *n_threads ), next_ic ); + const dim_t prev_prev_jc = prev_factor( ( *n_threads ), prev_jc ); + dim_t mu_ic_next_next = ( mu + next_next_ic - 1 ) / next_next_ic; + dim_t nu_jc_prev_prev = ( nu + prev_prev_jc - 1 ) / prev_prev_jc; + dim_t panel_work_next_next = mu_ic_next_next + nu_jc_prev_prev; + + if ( panel_work_next_next < panel_work_cur ) + { + panel_work_cur = panel_work_next_next; + ( *ic_ways ) = next_next_ic; + ( *jc_ways ) = prev_prev_jc; + } + + nu_mod_jc_ways = nu % ( *jc_ways ); + if ( ( nu_mod_jc_ways != 0 ) && ( next_next_ic < ( *n_threads ) ) ) + { + const dim_t next_next_next_ic = + next_factor + ( + ( *n_threads ), next_next_ic + ); + const dim_t prev_prev_prev_jc = + prev_factor + ( + ( *n_threads ), prev_prev_jc + ); + dim_t mu_ic_next_next_next = + ( mu + next_next_next_ic - 1 ) / next_next_next_ic; + dim_t nu_jc_prev_prev_prev = + ( nu + prev_prev_prev_jc - 1 ) / prev_prev_prev_jc; + dim_t panel_work_next_next_next = + mu_ic_next_next_next + nu_jc_prev_prev_prev; + + if ( panel_work_next_next_next < panel_work_cur ) + { + ( *ic_ways ) = next_next_next_ic; + ( *jc_ways ) = prev_prev_prev_jc; + } + } + } + } +} + +BLIS_INLINE void lpgemm_adjust_ic_jc_ways + ( + dim_t m, + dim_t n, + dim_t* n_threads, + dim_t* ic_ways, + dim_t* jc_ways + ) +{ + const dim_t m_ic = m / ( *ic_ways ); + const dim_t n_jc = n / ( *jc_ways ); + const int64_t cur_work_per_thread = m_ic + n_jc; + + const dim_t next_ic = next_factor( ( *n_threads ), ( *ic_ways ) ); + const dim_t prev_ic = prev_factor( ( *n_threads ), ( *ic_ways ) ); + const dim_t next_jc = next_factor( ( *n_threads ), ( *jc_ways ) ); + const dim_t prev_jc = prev_factor( ( *n_threads ), ( *jc_ways ) ); + + const dim_t m_next_ic = m / next_ic; + const dim_t m_prev_ic = m / prev_ic; + const dim_t n_next_jc = n / next_jc; + const dim_t n_prev_jc = n / prev_jc; + + const int64_t next_jc_work_per_thread = n_next_jc + m_prev_ic; + const int64_t next_ic_work_per_thread = m_next_ic + n_prev_jc; + + bool can_increase_ic = FALSE; + bool can_increase_jc = FALSE; + + if ( next_ic_work_per_thread <= cur_work_per_thread ) + { + can_increase_ic = TRUE; + } + else if ( next_jc_work_per_thread < cur_work_per_thread ) + { + can_increase_jc = TRUE; + } + + if ( can_increase_ic ) + { + ( *ic_ways ) = next_ic; + ( *jc_ways ) = prev_jc; + } + else if ( can_increase_jc ) + { + // Giving priority to ic and m dimensions, if m >= n, jc must be < ic. 
+ if ( ( ( m >= n ) && ( prev_ic >= next_jc ) ) || + ( ( m < n ) && ( prev_ic <= next_jc ) ) ) + { + ( *ic_ways ) = prev_ic; + ( *jc_ways ) = next_jc; + } + } +} + +BLIS_INLINE void lpgemm_u8s8s16o16_get_threading + ( + dim_t* n_threads, + dim_t* ic_ways, + dim_t* jc_ways, + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm_g + ) +{ + *n_threads = bli_rntm_num_threads( rntm_g ); + *jc_ways = bli_rntm_jc_ways( rntm_g ); + *ic_ways = bli_rntm_ic_ways( rntm_g ); + + if ( ( ( *ic_ways ) > 0 ) || ( ( *jc_ways ) > 0 ) ) + { + // If BLIS_IC_NT or JC_NT are set. + // Default cases. + *ic_ways = ( ( *ic_ways ) > 0 ) ? ( *ic_ways ) : 1; + *jc_ways = ( ( *jc_ways ) > 0 ) ? ( *jc_ways ) : 1; + + *n_threads = ( *jc_ways ) * ( *ic_ways ); + } + else if ( ( *n_threads ) > 1 ) + { + + dim_t NR = lpgemm_get_block_size_NR_global_cntx( U8S8S16OS16 ); + + if ( n <= NR ) + { + // If n is less than micro panel dimension, allocating all threads + // to ic resulted in gains. + ( *ic_ways ) = ( *n_threads ); + ( *jc_ways ) = 1; + } + else + { + // If BLIS_NUM_THREADS are set, generate jc,ic from the same. + bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); + } + } + else + { + // Setting all the values to 1 in case n_threads <= 1. This ensures + // the threading parameters are valid. + *n_threads = 1; + *jc_ways = 1; + *ic_ways = 1; + } +} + +BLIS_INLINE void lpgemm_u8s8s32o32_get_threading + ( + dim_t* n_threads, + dim_t* ic_ways, + dim_t* jc_ways, + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm_g + ) +{ + *n_threads = bli_rntm_num_threads( rntm_g ); + *jc_ways = bli_rntm_jc_ways( rntm_g ); + *ic_ways = bli_rntm_ic_ways( rntm_g ); + + if ( ( ( *ic_ways ) > 0 ) || ( ( *jc_ways ) > 0 ) ) + { + // If BLIS_IC_NT or JC_NT are set. + // Default cases. + *ic_ways = ( ( *ic_ways ) > 0 ) ? ( *ic_ways ) : 1; + *jc_ways = ( ( *jc_ways ) > 0 ) ? ( *jc_ways ) : 1; + + *n_threads = ( *jc_ways ) * ( *ic_ways ); + } + else if ( ( *n_threads ) > 1 ) + { + + dim_t NR = lpgemm_get_block_size_NR_global_cntx( U8S8S32OS32 ); + dim_t MR = lpgemm_get_block_size_MR_global_cntx( U8S8S32OS32 ); + + if ( n <= NR ) + { + // If n is less than micro panel dimension, allocating all threads + // to ic resulted in gains. + ( *ic_ways ) = ( *n_threads ); + ( *jc_ways ) = 1; + } + else + { + // If BLIS_NUM_THREADS are set, generate jc,ic from the same. + bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); + + lpgemm_adjust_ic_jc_ways( m, n, n_threads, ic_ways, jc_ways ); + + lpgemm_pnl_wrk_heur_adjust_ic_jc_ways + ( + MR, NR, m, n, + n_threads, ic_ways, jc_ways + ); + } + } + else + { + // Setting all the values to 1 in case n_threads <= 1. This ensures + // the threading parameters are valid. + *n_threads = 1; + *jc_ways = 1; + *ic_ways = 1; + } +} + +BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading + ( + dim_t* n_threads, + dim_t* ic_ways, + dim_t* jc_ways, + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm_g + ) +{ + *n_threads = bli_rntm_num_threads( rntm_g ); + *jc_ways = bli_rntm_jc_ways( rntm_g ); + *ic_ways = bli_rntm_ic_ways( rntm_g ); + + if ( ( ( *ic_ways ) > 0 ) || ( ( *jc_ways ) > 0 ) ) + { + // If BLIS_IC_NT or JC_NT are set. + // Default cases. + *ic_ways = ( ( *ic_ways ) > 0 ) ? ( *ic_ways ) : 1; + *jc_ways = ( ( *jc_ways ) > 0 ) ? 
( *jc_ways ) : 1; + + *n_threads = ( *jc_ways ) * ( *ic_ways ); + } + else if ( ( *n_threads ) > 1 ) + { + + dim_t NR = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ); + dim_t MR = lpgemm_get_block_size_MR_global_cntx( BF16BF16F32OF32 ); + + if ( n <= NR ) + { + // If n is less than micro panel dimension, allocating all threads + // to ic resulted in gains. + ( *ic_ways ) = ( *n_threads ); + ( *jc_ways ) = 1; + } + else + { + // If BLIS_NUM_THREADS are set, generate jc,ic from the same. + bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); + lpgemm_adjust_ic_jc_ways( m, n, n_threads, ic_ways, jc_ways ); + lpgemm_pnl_wrk_heur_adjust_ic_jc_ways + ( + MR, NR, m, n, + n_threads, ic_ways, jc_ways + ); + } + } + else + { + // Setting all the values to 1 in case n_threads <= 1. This ensures + // the threading parameters are valid. + *n_threads = 1; + *jc_ways = 1; + *ic_ways = 1; + } +} + +// Some aspects of sgemm smart threading incorporated here. Eventually this +// will be redirected to the sgemm smart threading API. +BLIS_INLINE void lpgemm_f32f32f32of32_get_threading + ( + dim_t* n_threads, + dim_t* ic_ways, + dim_t* jc_ways, + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm_g + ) +{ + // Query the global cntx. + cntx_t* cntx = bli_gks_query_cntx(); + + num_t dt = BLIS_FLOAT; + + // Query the context for SUP limits. + const dim_t MT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_MT, cntx ); + const dim_t NT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx ); + const dim_t KT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx ); + + const dim_t MT_2 = MT / 2; + + *n_threads = bli_rntm_num_threads( rntm_g ); + *jc_ways = bli_rntm_jc_ways( rntm_g ); + *ic_ways = bli_rntm_ic_ways( rntm_g ); + + if ( ( ( *ic_ways ) > 0 ) || ( ( *jc_ways ) > 0 ) ) + { + // If BLIS_IC_NT or JC_NT are set. + // Default cases. + *ic_ways = ( ( *ic_ways ) > 0 ) ? ( *ic_ways ) : 1; + *jc_ways = ( ( *jc_ways ) > 0 ) ? ( *jc_ways ) : 1; + + *n_threads = ( *jc_ways ) * ( *ic_ways ); + } + else if ( ( *n_threads ) > 1 ) + { + // If BLIS_NUM_THREADS are set, generate jc,ic from the same. + bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); + + lpgemm_adjust_ic_jc_ways( m, n, n_threads, ic_ways, jc_ways ); + } + else + { + // Setting all the values to 1 in case n_threads <= 1. This ensures + // the threading parameters are valid. + *n_threads = 1; + *jc_ways = 1; + *ic_ways = 1; + } + + // Native -> SUP path. 
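+ // Packing of A is enabled below only when m, n and k all cross the SUP
+ // thresholds and either k spans more than two pages worth of floats,
+ // or, for smaller k, the per thread tile is still large enough
+ // ( m_ic > MT / 2 and n_jc >= NT ).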
+ const dim_t m_ic = m / ( *ic_ways ); + const dim_t n_jc = n / ( *jc_ways ); + const dim_t page_size = bli_info_get_page_size(); + const dim_t page_size_b_floatx2 = + 2 * ( page_size / sizeof( float ) ); + + if ( ( m >= MT ) && ( n >= NT ) && ( k >= KT ) ) + { + if ( ( k > page_size_b_floatx2 ) || + ( ( k <= page_size_b_floatx2 ) && + ( m_ic > MT_2 ) && ( n_jc >= NT ) ) ) + { + bli_rntm_set_pack_a( 1, rntm_g ); + } + } +} + +#define GEN_LPGEMM_OPENMP_DECORATOR(A_type,B_type,C_type,LPGEMM_SFX) \ +void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ + ( \ + const dim_t m, \ + const dim_t n, \ + const dim_t k, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const AOCL_MEMORY_TAG mtag_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + const AOCL_MEMORY_TAG mtag_b, \ + C_type* c, \ + const dim_t rs_c, \ + const dim_t cs_c, \ + C_type alpha, \ + C_type beta, \ + rntm_t* rntm_g, \ + lpgemm_post_op* post_op_list, \ + bool c_downscale \ + ) \ +{ \ + dim_t n_threads; \ + \ + /* Factorization of threads along m and n dimension respectively.*/ \ + dim_t ic_ways; \ + dim_t jc_ways; \ + \ + lpgemm_ ## LPGEMM_SFX ## _get_threading \ + ( \ + &n_threads, \ + &ic_ways, &jc_ways, \ + m, n, k, rntm_g \ + ); \ + \ + /* Set the packing block allocator field of the rntm. This will be + * inherited by all of the child threads when they make local copies of + * the rntm below.*/ \ + bli_membrk_rntm_set_membrk( rntm_g ); \ + \ + thrcomm_t static_lpgemm_comms[BLIS_LPGEMM_NUM_STATIC_COMMS]; \ + thrcomm_t* cur_lpgemm_comms = static_lpgemm_comms; \ + \ + if ( jc_ways > BLIS_LPGEMM_NUM_STATIC_COMMS ) \ + { \ + cur_lpgemm_comms = bli_malloc_intl( jc_ways * sizeof( thrcomm_t ) ); \ + } \ + for ( dim_t i = 0; i < jc_ways; ++i ) \ + { \ + bli_thrcomm_init( ic_ways, &cur_lpgemm_comms[i] ); \ + } \ + \ + _Pragma( "omp parallel num_threads(n_threads)" ) \ + { \ + /* Create a thread-local copy of the master thread's rntm_t. 
This is + * necessary since we want each thread to be able to track its own + * small block pool_t as it executes down the function stack.*/ \ + rntm_t rntm_l = *rntm_g; \ + \ + /* lpgemm_thrinfo_t object will be used to generate thrinfo_t objects + * for use in blis mt framework inside the respective mat mul driver + * functions.*/ \ + lpgemm_thrinfo_t thread; \ + thread.n_threads = n_threads; \ + thread.tid = omp_get_thread_num(); \ + thread.ic_ways = ic_ways; \ + thread.jc_ways = jc_ways; \ + thread.comm = cur_lpgemm_comms; \ + \ + lpgemm_rowvar_ ## LPGEMM_SFX \ + ( \ + m, n, k, \ + a, rs_a, cs_a, mtag_a, \ + b, rs_b, cs_b, mtag_b, \ + c, rs_c, cs_c,\ + alpha, \ + beta, \ + &rntm_l, \ + &thread, \ + post_op_list, c_downscale \ + ); \ + } \ + if ( jc_ways > BLIS_LPGEMM_NUM_STATIC_COMMS ) \ + { \ + bli_free_intl( cur_lpgemm_comms ); \ + } \ +} \ + +GEN_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int16_t,u8s8s16o16) +GEN_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32) +GEN_LPGEMM_OPENMP_DECORATOR(bfloat16,bfloat16,float,bf16bf16f32of32) +GEN_LPGEMM_OPENMP_DECORATOR(float,float,float,f32f32f32of32) + +#else + +#define GEN_LPGEMM_DECORATOR(A_type,B_type,C_type,LPGEMM_SFX) \ +void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ + ( \ + const dim_t m, \ + const dim_t n, \ + const dim_t k, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const AOCL_MEMORY_TAG mtag_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + const AOCL_MEMORY_TAG mtag_b, \ + C_type* c, \ + const dim_t rs_c, \ + const dim_t cs_c, \ + C_type alpha, \ + C_type beta, \ + rntm_t* rntm_g, \ + lpgemm_post_op* post_op_list, \ + bool c_downscale \ + ) \ +{ \ + dim_t n_threads = 1; \ + \ + /* Factorization of threads along m and n dimension respectively.*/ \ + dim_t ic_ways = 1; \ + dim_t jc_ways = 1; \ + \ + /* Set the packing block allocator field of the rntm. This will be + * inherited by all of the child threads when they make local copies of + * the rntm below.*/ \ + bli_membrk_rntm_set_membrk( rntm_g ); \ + \ + thrcomm_t static_lpgemm_comm; \ + thrcomm_t* cur_lpgemm_comm = &static_lpgemm_comm; \ + \ + bli_thrcomm_init( ic_ways, cur_lpgemm_comm ); \ + \ + /* lpgemm_thrinfo_t object will be used to generate thrinfo_t objects + * for use in blis mt framework inside the respective mat mul driver + * functions.*/ \ + lpgemm_thrinfo_t thread; \ + thread.n_threads = n_threads; \ + thread.tid = 0; \ + thread.ic_ways = ic_ways; \ + thread.jc_ways = jc_ways; \ + thread.comm = cur_lpgemm_comm; \ + \ + lpgemm_rowvar_ ## LPGEMM_SFX \ + ( \ + m, n, k, \ + a, rs_a, cs_a, mtag_a, \ + b, rs_b, cs_b, mtag_b, \ + c, rs_c, cs_c, \ + alpha, \ + beta, \ + rntm_g, \ + &thread, \ + post_op_list, c_downscale \ + ); \ +} \ + +GEN_LPGEMM_DECORATOR(uint8_t,int8_t,int16_t,u8s8s16o16) +GEN_LPGEMM_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32) +GEN_LPGEMM_DECORATOR(bfloat16,bfloat16,float,bf16bf16f32of32) +GEN_LPGEMM_DECORATOR(float,float,float,f32f32f32of32) + +#endif diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h new file mode 100644 index 0000000000..8055d623e6 --- /dev/null +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h @@ -0,0 +1,106 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_THREAD_DECOR_OPENMP_H +#define LPGEMM_THREAD_DECOR_OPENMP_H + +#include "lpgemm_types.h" +#include "lpgemm_post_ops.h" +#include "aocl_bf16_type.h" + +#ifdef BLIS_ENABLE_OPENMP + +#define GEN_LPGEMM_OPENMP_DECORATOR_FN(A_type,B_type,C_type,LPGEMM_SFX) \ +void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ + ( \ + const dim_t m, \ + const dim_t n, \ + const dim_t k, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const AOCL_MEMORY_TAG mtag_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + const AOCL_MEMORY_TAG mtag_b, \ + C_type* c, \ + const dim_t rs_c, \ + const dim_t cs_c, \ + C_type alpha, \ + C_type beta, \ + rntm_t* rntm_g, \ + lpgemm_post_op* post_op_list, \ + bool c_downscale \ + ); \ + +GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int16_t,u8s8s16o16) +GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32) +GEN_LPGEMM_OPENMP_DECORATOR_FN(bfloat16,bfloat16,float,bf16bf16f32of32) +GEN_LPGEMM_OPENMP_DECORATOR_FN(float,float,float,f32f32f32of32) + +#else + +#define GEN_LPGEMM_DECORATOR_FN(A_type,B_type,C_type,LPGEMM_SFX) \ +void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ + ( \ + const dim_t m, \ + const dim_t n, \ + const dim_t k, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const AOCL_MEMORY_TAG mtag_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + const AOCL_MEMORY_TAG mtag_b, \ + C_type* c, \ + const dim_t rs_c, \ + const dim_t cs_c, \ + C_type alpha, \ + C_type beta, \ + rntm_t* rntm_g, \ + lpgemm_post_op* post_op_list, \ + bool c_downscale \ + ); \ + +GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int16_t,u8s8s16o16) +GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32) +GEN_LPGEMM_DECORATOR_FN(bfloat16,bfloat16,float,bf16bf16f32of32) +GEN_LPGEMM_DECORATOR_FN(float,float,float,f32f32f32of32) + +#endif + +#endif //LPGEMM_THREAD_DECOR_OPENMP_H diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thrinfo_utils.h b/addon/aocl_gemm/frame/threading/lpgemm_thrinfo_utils.h new file mode 100644 index 
0000000000..2ac9b505a6 --- /dev/null +++ b/addon/aocl_gemm/frame/threading/lpgemm_thrinfo_utils.h @@ -0,0 +1,78 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_THRINFO_UTILS_H +#define LPGEMM_THRINFO_UTILS_H + +// Parallelization only supported along jc and ic loops. Thus not reusing the +// existing thrinfo tree logic, since a light-weight work id generation will +// suffice. However the logic used for thread meta data generation, specific +// to jc and ic loops is borrowed. +BLIS_INLINE void lpgemm_gen_thrinfo + ( + lpgemm_thrinfo_t* thread, + thrinfo_t* thread_jc, + thrinfo_t* thread_ic + ) +{ + if ( thread == NULL ) + { + // Set n_ways=1 to ensure ST behaviour when thread is not initialized. + // This is the case when BLIS_ENABLE_OPENMP is not defined. + bli_thrinfo_set_ocomm_id( 0, thread_jc ); + bli_thrinfo_set_n_way( 1, thread_jc ); + bli_thrinfo_set_work_id( 0, thread_jc ); + + bli_thrinfo_set_ocomm_id( 0, thread_ic ); + bli_thrinfo_set_n_way( 1, thread_ic ); + bli_thrinfo_set_work_id( 0, thread_ic ); + } + else + { + // Replicate the logic in bli_l3_sup_thrinfo_create_root for jc thrinfo. + bli_thrinfo_set_ocomm_id( thread->tid, thread_jc ); + bli_thrinfo_set_n_way( thread->jc_ways, thread_jc ); + dim_t jc_work_id = thread->tid / thread->ic_ways; + bli_thrinfo_set_work_id( jc_work_id, thread_jc ); + + // Replicate the sub node creation logic in bli_thrinfo_sup_create_for_cntl + // for ic thrinfo. 
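+ // Threads are grouped ic-major within each jc group, i.e.
+ // tid = jc_work_id * ic_ways + ic_comm_id. For example, with
+ // ic_ways = 4 and jc_ways = 2, tid 6 maps to jc work id 1 and
+ // ic comm id 2.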
+ dim_t ic_comm_id = thread->tid % thread->ic_ways; + bli_thrinfo_set_ocomm_id( ic_comm_id, thread_ic ); + bli_thrinfo_set_n_way( thread->ic_ways, thread_ic ); + bli_thrinfo_set_work_id( ic_comm_id, thread_ic ); + } +} + +#endif //LPGEMM_THRINFO_UTILS_H diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c new file mode 100644 index 0000000000..0b55f31215 --- /dev/null +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c @@ -0,0 +1,168 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "blis.h" +#include "lpgemm_utils.h" +#include "lpgemm_reorder_s16.h" +#include "lpgemm_packb_s16.h" +#include "lpgemm_config.h" + +void aocl_reorderb_nr32_u8s8s16o16 + ( + lpgemm_obj_t *b, + lpgemm_obj_t *b_reorder + ) +{ + const dim_t NC = lpgemm_get_block_size_NC_global_cntx(U8S8S16OS16); + const dim_t KC = lpgemm_get_block_size_KC_global_cntx(U8S8S16OS16); + const dim_t NR = lpgemm_get_block_size_NR_global_cntx(U8S8S16OS16); + + // Extracting the matrix properties from the lpgemm object + dim_t rs_b = b->rs; + dim_t n = b->width; + dim_t k = b->length; + + dim_t rs_b_reorder; + dim_t cs_b_reorder; + + dim_t k_updated = k; + + // Making multiple of 2 to suit k in vpmaddubsw + k_updated += (k_updated & 0x1); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + + dim_t n_threads = bli_rntm_num_threads( &rntm_g ); + n_threads = ( n_threads > 0 ) ? n_threads : 1; + +#ifdef BLIS_ENABLE_OPENMP + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Initialise a local thrinfo obj for work split across threads. 
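+ // Only the jc (n) dimension is parallelized for reordering; each
+ // thread walks the full pc (k) loop within its jc range, so a single
+ // jc-level thrinfo object is sufficient here.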
+ thrinfo_t thread_jc; + bli_thrinfo_set_n_way( n_threads, &thread_jc ); + bli_thrinfo_set_work_id( omp_get_thread_num(), &thread_jc ); +#else + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( 1, &thread_jc ); + bli_thrinfo_set_work_id( 0, &thread_jc ); +#endif + // Compute the JC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) + { + dim_t nc0 = bli_min( ( jc_end - jc ), NC ); + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated; + + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, 16, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + dim_t kc0 = bli_min( ( k - pc ), KC ); + + // kc0 needs to be a multiple of 2 so that it can be used with + // vmaddubsw instruction. Padding is added in cases this + // condition is not satisfied, and therefore the kc0 offsets + // used for packed/reordered buffers needs to be updated. + dim_t kc0_updated = make_multiple_of_n( kc0, 2 ); + + // The offsets are calculated in such a way that it resembles + // the reorder buffer traversal in single threaded reordering. + // The panel boundaries (KCxNC) remain as it is accessed in + // single thread, and as a consequence a thread with jc_start + // inside the panel cannot consider NC range for reorder. It + // has to work with NC' < NC, and the offset is calulated using + // prev NC panels spanning k dim + cur NC panel spaning pc loop + // cur iteration + (NC - NC') spanning current kc0 (<= KC). + // + //Eg: Consider the following reordered buffer diagram: + // t1 t2 + // | | + // | |..NC..| + // | | | + // |.NC. |.NC. |NC'|NC" + // pc=0-+-----+-----+---+--+ + // KC| | | | | + // | 1 | 3 | 5 | + // pc=KC-+-----+-----+---st-+ + // KC| | | | | + // | 2 | 4 | 6 | 7| + // pc=k=2KC-+-----+-----+---+--+ + // |jc=0 |jc=NC|jc=2NC| + // + // The numbers 1,2..6,7 denotes the order in which reordered + // KCxNC blocks are stored in memory, ie: block 1 followed by 2 + // followed by 3, etc. Given two threads t1 and t2, and t2 needs + // to acces point st in the reorder buffer to write the data: + // The offset calulation logic will be: + // jc_cur_loop = 2NC, jc_cur_loop_rem = NC', pc = KC, + // n_sub_updated = NC, k = 2KC, kc0_updated = KC + // + // st = ( jc_cur_loop * k ) + // + ( n_sub_updated * pc ) + // + ( NC' * kc0_updated) + packb_nr32_u8s8s16o16 + ( + ( ( ( int8_t* )b_reorder->storage.aligned_buffer ) + + ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) + + ( jc_cur_loop_rem * kc0_updated ) ), + ( ( ( int8_t* )b->storage.aligned_buffer ) + + ( rs_b * pc ) + jc ), + rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder + ); + } + + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); + } + } + + // Changing the packed matrix properties in the packed matrix object + b_reorder->rs = rs_b_reorder; + b_reorder->cs = cs_b_reorder; + b_reorder->mtag = REORDERED; +} diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.h b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.h new file mode 100644 index 0000000000..6018978bc7 --- /dev/null +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.h @@ -0,0 +1,45 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#ifndef LPGEMM_REORDER_S16_H +#define LPGEMM_REORDER_S16_H + +#include "lpgemm_types.h" + +void aocl_reorderb_nr32_u8s8s16o16 + ( + lpgemm_obj_t *b, + lpgemm_obj_t *b_reorder + ); + +#endif // LPGEMM_REORDER_H diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c new file mode 100644 index 0000000000..b8f5115429 --- /dev/null +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c @@ -0,0 +1,339 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_packb_s16.h" +#include "lpgemm_kernels.h" +#include "lpgemm_utils.h" +#include "lpgemm_config.h" +#include "lpgemm_thrinfo_utils.h" + +// B should always be packed. +LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) +{ + const dim_t NC = lpgemm_get_block_size_NC_global_cntx( U8S8S16OS16 ); + const dim_t KC = lpgemm_get_block_size_KC_global_cntx( U8S8S16OS16 ); + const dim_t MC = lpgemm_get_block_size_MC_global_cntx( U8S8S16OS16 ); + const dim_t NR = lpgemm_get_block_size_NR_global_cntx( U8S8S16OS16 ); + const dim_t MR = lpgemm_get_block_size_MR_global_cntx( U8S8S16OS16 ); + + if (mtag_b == UNPACKED) + { + // Error: can only work with packed B now. + return; + } + + const int8_t *b_use; + const uint8_t *a_use; + dim_t rs_a_use = rs_a; + dim_t cs_a_use = cs_a; + + dim_t rs_b_use = rs_b; + dim_t cs_b_use = cs_b; + + int16_t *c_use_jc = NULL; + int16_t *c_use_ic = NULL; + dim_t rs_c_use = rs_c; + dim_t rs_c_downscale = rs_c; + + // Pack buffer for B. + int8_t *pack_b_buffer_u8s8s16o16; + mem_t mem_b = BLIS_MEM_INITIALIZER; + dim_t packb_min_NR = 16; + siz_t mem_b_size_req = 0; + + // Temporary buffer for C accumulation when downscaling is required. + int16_t* temp_scal_c_buffer_u8s8s16o16; + mem_t mem_scale_c = BLIS_MEM_INITIALIZER; + siz_t mem_scale_c_size_req = 0; + + // Making multiple of 2 to suit k in vpmaddubsw + dim_t k_updated = make_multiple_of_n( k, 2 ); + + // Is required to decide whether to apply post ops or not. + bool is_last_k = FALSE; + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + + lpgemm_gen_thrinfo(thread, &thread_jc, &thread_ic); + + // Compute the JC, IC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end); + + dim_t ic_start, ic_end; + bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end); + + for (dim_t jc = jc_start; jc < jc_end; jc += NC) + { + dim_t nc0 = bli_min((jc_end - jc), NC); + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated = 0; + + if (mtag_b == REORDERED) + { + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, packb_min_NR, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + } + + if ( c_downscale == FALSE ) + { + c_use_jc = c + jc; + } + // Temp accumulaton buffer for C allocation. + else if ( c_downscale == TRUE ) + { + mem_scale_c_size_req = sizeof( int16_t ) * nc0 * ( ic_end - ic_start ); + + lpgemm_alloc_mem_panel + ( + mem_scale_c_size_req, BLIS_BUFFER_FOR_C_PANEL, + &mem_scale_c, rntm + ); + + temp_scal_c_buffer_u8s8s16o16 = bli_mem_buffer( &mem_scale_c ); + + c_use_jc = ( int16_t* )temp_scal_c_buffer_u8s8s16o16; + + if ( beta != 0 ) + { + dim_t i_temp = 0; + dim_t j_temp = 0; + // Upscale out C to temporary C matrix. 
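+ // When downscaling, the user-visible C holds int8 values; for beta
+ // accumulation they are widened to int16 into the temp buffer, which
+ // uses nc0 as its row stride and ( ic_start, jc ) as its origin.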
+ for ( dim_t i_dscale = ic_start; i_dscale < ic_end; ++i_dscale ) + { + j_temp = 0; + for ( dim_t j_dscale = jc; j_dscale < ( jc + nc0 ); ++j_dscale ) + { + *( temp_scal_c_buffer_u8s8s16o16 + + ( nc0 * i_temp ) + j_temp ) = + ( int16_t )( *( ( ( int8_t* )c ) + + ( rs_c * i_dscale ) + j_dscale ) ); + + j_temp++; + } + i_temp++; + } + } + + // The temp c buffer stride is modified as opposed to original C matrix. + rs_c_use = nc0; + } + + for (dim_t pc = 0; pc < k; pc += KC) + { + int16_t beta0 = (pc == 0) ? beta : 1; + dim_t kc0 = bli_min((k - pc), KC); + + is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE ); + + // kc0 needs to be a multiple of 2 so that it can be + // used with vpmaddubsw instruction. Padding is added in + // cases this condition is not satisfied, and therefore + // the kc0 offsets used for packed/reordered buffers + // needs to be updated. + dim_t kc0_updated = make_multiple_of_n(kc0, 2); + + if (mtag_b == PACK) + { + // Pack B chunks are based on jc work id. + dim_t jc_work_id = bli_thread_work_id(&thread_jc); + + // Using child thrinfo (thread_ic) tid to decide chief thread + // per B matrix chunk (jc work id group) + if (bli_thread_am_ochief(&thread_ic)) + { + // nc0 needs to be a multiple of 16 since this gives maximum + // vectorization. Packing B always results in buffers with width + // which is a multiple of 16. Subsequently the nc0 offsets used + // for packed/reordered buffers needs to be updated. + dim_t nc0_updated = make_multiple_of_n(nc0, packb_min_NR); + mem_b_size_req = sizeof(int8_t) * nc0_updated * kc0_updated; + + lpgemm_alloc_mem_panel( + mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL, + &mem_b, rntm); + + thread->comm[jc_work_id].sent_object = + bli_mem_buffer(&mem_b); + } + + // All threads in work group should wait till chief thread has + // finished allocating the packing buffers. + bli_thrcomm_barrier( + bli_thread_ocomm_id(&thread_ic), + &thread->comm[jc_work_id]); + + pack_b_buffer_u8s8s16o16 = + (int8_t *)thread->comm[jc_work_id].sent_object; + + // Compute the B panel per thread loop range for parallel + // packing using ic_ways number of threads. Since atmost only + // ic_ways threads can be used, the thread_ic attributes are + // used to split the loop range. + dim_t jc_packb_start, jc_packb_end; + bli_thread_range_sub + ( + &thread_ic, nc0, NR, FALSE, + &jc_packb_start, &jc_packb_end + ); + + // Ensure thread ranges are valid, especially cases where no: + // of threads available for parallelization are greater than + // no: of B panel NR chunks. + if ((jc_packb_end > jc_packb_start) && + (jc_packb_start < (jc + nc0))) + { + packb_nr32_u8s8s16o16 + ( + pack_b_buffer_u8s8s16o16 + + (jc_packb_start * kc0_updated), + (b + (rs_b * pc) + (cs_b * jc) + + (cs_b * jc_packb_start)), + rs_b, + (jc_packb_end - jc_packb_start), kc0, + &rs_b_use, &cs_b_use + ); + } + else + { + get_packb_nr32_u8s8s16o16_strides(&rs_b_use, &cs_b_use); + } + + // All threads in work group should wait till B matrix packing + // is completed by the participating threads. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id(&thread_ic), + &thread->comm[jc_work_id] + ); + + b_use = pack_b_buffer_u8s8s16o16; + } + else if (mtag_b == REORDERED) + { + // In multi-threaded scenarios, an extra offset into a given + // packed B panel is required, since the jc loop split can + // result in per thread start offset inside the panel, instead + // of panel boundaries. 
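+ // Here ( jc_cur_loop * k_updated ) skips the reordered KCxNC panels of
+ // earlier NC blocks, ( n_sub_updated * pc ) skips earlier KC blocks of
+ // the current NC panel, and ( jc_cur_loop_rem * kc0_updated ) moves to
+ // this thread's start inside the current KCxNC block (see the diagram
+ // in lpgemm_reorder_s16.c).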
+ b_use = b + (jc_cur_loop * k_updated) + + (n_sub_updated * pc) + + (jc_cur_loop_rem * kc0_updated); + + get_packb_nr32_u8s8s16o16_strides(&rs_b_use, &cs_b_use); + } + else + { + // Unpacked B not supported. + return; + } + + for (dim_t ic = ic_start; ic < ic_end; ic += MC) + { + dim_t mc0 = bli_min((ic_end - ic), MC); + + // Only per thread C matrix is stored in temp buffer, so both + // per thread jc and ic start should be normalized to zero. + if ( c_downscale == TRUE ) + { + c_use_ic = c_use_jc + ( rs_c_use * ( ic - ic_start ) ); + } + else + { + c_use_ic = c_use_jc + ( rs_c_use * ic ); + } + + a_use = a + (rs_a * ic) + (cs_a * pc); + cs_a_use = 1; + + dim_t a_block_stride = rs_a; + + for (dim_t jr = 0; jr < nc0; jr += NR) + { + dim_t nr0 = bli_min((nc0 - jr), NR); + + // Calls for reorder B + lpgemm_rowvar_u8s8s16o16_6x32 + ( + mc0, nr0, kc0, + a_use, rs_a_use, cs_a_use, a_block_stride, + (b_use + (jr * kc0_updated)), rs_b_use, cs_b_use, + (c_use_ic + jr), rs_c_use, 1, + alpha, beta0, + is_last_k, ic, ( jc + jr ), post_op_list, rs_c_downscale + ); + } + } + } + + if (mtag_b == REORDERED) + { + adjust_B_panel_reordered_jc(&jc, jc_cur_loop); + } + } + + // Release pack buffers. + if (mtag_b == PACK) + { + // All threads in work group should wait till B matrix usage is + // completed by the participating threads. + bli_thrcomm_barrier( + bli_thread_ocomm_id(&thread_jc), + &thread->comm[bli_thread_work_id(&thread_jc)]); + + if (bli_thread_am_ochief(&thread_ic)) + { + if (bli_mem_is_alloc(&mem_b)) + { + bli_membrk_release(rntm, &mem_b); + } + } + } + if ( c_downscale == TRUE ) + { + if ( bli_mem_is_alloc( &mem_scale_c ) ) + { + bli_membrk_release( rntm, &mem_scale_c ); + } + } +} diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c new file mode 100644 index 0000000000..746a134100 --- /dev/null +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c @@ -0,0 +1,231 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_utils.h" +#include "lpgemm_reorder.h" +#include "lpgemm_packa.h" +#include "lpgemm_packb.h" +#include "lpgemm_config.h" + +void reorderb_nr64_u8s8s32o32 + ( + lpgemm_obj_t* b, + lpgemm_obj_t* b_reorder + ) +{ + dim_t NC = lpgemm_get_block_size_NC_global_cntx( U8S8S32OS32 ); + dim_t NR = lpgemm_get_block_size_NR_global_cntx( U8S8S32OS32 ); + dim_t KC = lpgemm_get_block_size_KC_global_cntx( U8S8S32OS32 ); + + dim_t rs_b = b->rs; + dim_t rs_b_reorder; + dim_t cs_b_reorder; + + dim_t n = b->width; + dim_t k = b->length; + + // k needs to be a multiple of 4 so that it can be used with vpdpbusd + // instruction. Padding is added in cases this condition is not + // satisfied, and therefore the k offset used for packed/reordered + // buffer needs to be updated. + dim_t k_updated = make_multiple_of_n( k, 4 ); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + + dim_t n_threads = bli_rntm_num_threads( &rntm_g ); + n_threads = ( n_threads > 0 ) ? n_threads : 1; + +#ifdef BLIS_ENABLE_OPENMP + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( n_threads, &thread_jc ); + bli_thrinfo_set_work_id( omp_get_thread_num(), &thread_jc ); +#else + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( 1, &thread_jc ); + bli_thrinfo_set_work_id( 0, &thread_jc ); +#endif + // Compute the JC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) + { + dim_t nc0 = bli_min( ( jc_end - jc ), NC ); + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated; + + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, get_packb_u8s8s32o32_min_NR(), + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + dim_t kc0 = bli_min( ( k - pc ), KC ); + + // kc0 needs to be a multiple of 4 so that it can be used with + // vpdpbusd instruction. Padding is added in cases this + // condition is not satisfied, and therefore the kc0 offsets + // used for packed/reordered buffers needs to be updated. + dim_t kc0_updated = make_multiple_of_n( kc0, 4 ); + + // The offsets are calculated in such a way that it resembles + // the reorder buffer traversal in single threaded reordering. + // The panel boundaries (KCxNC) remain as it is accessed in + // single thread, and as a consequence a thread with jc_start + // inside the panel cannot consider NC range for reorder. 
It + // has to work with NC' < NC, and the offset is calulated using + // prev NC panels spanning k dim + cur NC panel spaning pc loop + // cur iteration + (NC - NC') spanning current kc0 (<= KC). + // + //Eg: Consider the following reordered buffer diagram: + // t1 t2 + // | | + // | |..NC..| + // | | | + // |.NC. |.NC. |NC'|NC" + // pc=0-+-----+-----+---+--+ + // KC| | | | | + // | 1 | 3 | 5 | + // pc=KC-+-----+-----+---st-+ + // KC| | | | | + // | 2 | 4 | 6 | 7| + // pc=k=2KC-+-----+-----+---+--+ + // |jc=0 |jc=NC|jc=2NC| + // + // The numbers 1,2..6,7 denotes the order in which reordered + // KCxNC blocks are stored in memory, ie: block 1 followed by 2 + // followed by 3, etc. Given two threads t1 and t2, and t2 needs + // to acces point st in the reorder buffer to write the data: + // The offset calulation logic will be: + // jc_cur_loop = 2NC, jc_cur_loop_rem = NC', pc = KC, + // n_sub_updated = NC, k = 2KC, kc0_updated = KC + // + // st = ( jc_cur_loop * k ) + // + ( n_sub_updated * pc ) + // + ( NC' * kc0_updated) +#ifdef BLIS_KERNELS_ZEN4 + packb_nr64_u8s8s32o32 + ( + ( ( ( int8_t* )b_reorder->storage.aligned_buffer ) + + ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) + + ( jc_cur_loop_rem * kc0_updated ) ), + ( ( ( int8_t* )b->storage.aligned_buffer ) + + ( rs_b * pc ) + jc ), + rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder + ); +#else + // Silence compiler warnings. + rs_b_reorder = 0; + cs_b_reorder = 0; + ( void )kc0_updated; + ( void )k_updated; + ( void )rs_b; +#endif + } + + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); + } + } + + b_reorder->rs = rs_b_reorder; + b_reorder->cs = cs_b_reorder; + b_reorder->mtag = REORDERED; +} + +void reordera_mr6_u8s8s32o32 + ( + lpgemm_obj_t* a, + lpgemm_obj_t* a_reorder + ) +{ + dim_t MC = lpgemm_get_block_size_MC_global_cntx( U8S8S32OS32 ); + dim_t KC = lpgemm_get_block_size_KC_global_cntx( U8S8S32OS32 ); + + dim_t rs_a = a->rs; + dim_t rs_a_reorder; + dim_t cs_a_reorder; + + dim_t k = a->width; + dim_t m = a->length; + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + dim_t kc0 = bli_min( ( k - pc ), KC ); + + // kc0 needs to be a multiple of 4 so that it can be used with + // vpdpbusd instruction. Padding is added in cases this + // condition is not satisfied, and therefore the kc0 offsets + // used for packed/reordered buffers needs to be updated. + dim_t kc0_updated = make_multiple_of_n( kc0, 4 ); + + for ( dim_t ic = 0; ic < m; ic += MC ) + { + dim_t mc0 = bli_min( ( m - ic ), MC ); + +#ifdef BLIS_KERNELS_ZEN4 + packa_k64_u8s8s32o32 + ( + ( ( ( uint8_t* )a_reorder->storage.aligned_buffer ) + ( pc * m ) + + ( ic * kc0_updated ) ), + ( ( ( uint8_t* )a->storage.aligned_buffer ) + ( rs_a * ic ) + pc ), + rs_a, mc0, kc0, &rs_a_reorder, &cs_a_reorder + ); +#else + rs_a_reorder = 0; + cs_a_reorder = 0; + ( void )kc0_updated; + ( void )rs_a; + ( void )mc0; +#endif + } + } + + a_reorder->rs = rs_a_reorder; + a_reorder->cs = cs_a_reorder; + a_reorder->mtag = REORDERED; +} diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.h b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.h new file mode 100644 index 0000000000..eb8dad9cfc --- /dev/null +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.h @@ -0,0 +1,52 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_REORDER_H +#define LPGEMM_REORDER_H + +#include "lpgemm_types.h" + +void reorderb_nr64_u8s8s32o32 + ( + lpgemm_obj_t* b, + lpgemm_obj_t* b_reorder + ); + +void reordera_mr6_u8s8s32o32 + ( + lpgemm_obj_t* a, + lpgemm_obj_t* a_reorder + ); + +#endif //LPGEMM_REORDER_H diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c new file mode 100644 index 0000000000..82a745fcf5 --- /dev/null +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c @@ -0,0 +1,410 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_packa.h" +#include "lpgemm_packb.h" +#include "lpgemm_kernels.h" +#include "lpgemm_utils.h" +#include "lpgemm_thrinfo_utils.h" +#include "lpgemm_config.h" + +// B should always be packed. +LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) +{ + dim_t NC = lpgemm_get_block_size_NC_global_cntx( U8S8S32OS32 ); + dim_t KC = lpgemm_get_block_size_KC_global_cntx( U8S8S32OS32 ); + dim_t MC = lpgemm_get_block_size_MC_global_cntx( U8S8S32OS32 ); + dim_t NR = lpgemm_get_block_size_NR_global_cntx( U8S8S32OS32 ); + dim_t MR = lpgemm_get_block_size_MR_global_cntx( U8S8S32OS32 ); + + if ( mtag_b == UNPACKED ) + { + //Error: can only work with packed B now. + return; + } + + // Strides are updated based on matrix packing/reordering. + const uint8_t* a_use = NULL; + dim_t rs_a_use = rs_a; + dim_t cs_a_use = cs_a; + dim_t a_block_stride = 0; + + const int8_t* b_use = NULL; + dim_t rs_b_use = rs_b; + dim_t cs_b_use = cs_b; + + int32_t* c_use_jc = NULL; + int32_t* c_use_ic = NULL; + dim_t rs_c_use = rs_c; + dim_t rs_c_downscale = rs_c; + + // Pack buffer for A. + uint8_t* pack_a_buffer_u8s8s32o32; + mem_t mem_a = BLIS_MEM_INITIALIZER; + siz_t mem_a_size_req = 0; + + // Pack buffer for B. + int8_t* pack_b_buffer_u8s8s32o32; + mem_t mem_b = BLIS_MEM_INITIALIZER; + siz_t mem_b_size_req = 0; + dim_t packb_min_NR = get_packb_u8s8s32o32_min_NR(); + + // Temporary buffer for C accumulation when downscaling is required. + int32_t* temp_scal_c_buffer_u8s8s32o32; + mem_t mem_scale_c = BLIS_MEM_INITIALIZER; + siz_t mem_scale_c_size_req = 0; + + // kc needs to be a multiple of 4 so that it can be used with vpdpbusd + // instruction. Padding is added in cases this condition is not + // satisfied, and therefore the k offset used for packed/reordered + // buffer needs to be updated. + dim_t k_updated = make_multiple_of_n( k, 4 ); + + // Is required to decide whether to apply post ops or not. + bool is_last_k = FALSE; + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + + lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic ); + + // Compute the JC, IC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + dim_t ic_start, ic_end; + bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end ); + + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) + { + dim_t nc0 = bli_min( ( jc_end - jc ), NC ); + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated; + + if ( mtag_b == REORDERED ) + { + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, packb_min_NR, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + } + + if ( c_downscale == FALSE ) + { + c_use_jc = c + jc; + } + // Temp accumulaton buffer for C allocation. 
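+		// The per thread buffer holds ( ic_end - ic_start ) x nc0 int32
+		// elements with row stride nc0; when beta != 0 the int8 output
+		// values are first upscaled into it so that accumulation happens
+		// at int32 precision before the downscale post-op writes back.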
+ else if ( c_downscale == TRUE ) + { + mem_scale_c_size_req = sizeof( int32_t ) * nc0 * ( ic_end - ic_start ); + + lpgemm_alloc_mem_panel + ( + mem_scale_c_size_req, BLIS_BUFFER_FOR_C_PANEL, + &mem_scale_c, rntm + ); + + temp_scal_c_buffer_u8s8s32o32 = bli_mem_buffer( &mem_scale_c ); + + c_use_jc = ( int32_t* )temp_scal_c_buffer_u8s8s32o32; + + if ( beta != 0 ) + { + dim_t i_temp = 0; + dim_t j_temp = 0; + // Upscale out C to temporary C matrix. + for ( dim_t i_dscale = ic_start; i_dscale < ic_end; ++i_dscale ) + { + j_temp = 0; + for ( dim_t j_dscale = jc; j_dscale < ( jc + nc0 ); ++j_dscale ) + { + *( temp_scal_c_buffer_u8s8s32o32 + + ( nc0 * i_temp ) + j_temp ) = + ( int32_t )( *( ( ( int8_t* )c ) + + ( rs_c * i_dscale ) + j_dscale ) ); + + j_temp++; + } + i_temp++; + } + } + + // The temp c buffer stride is modified as opposed to original C matrix. + rs_c_use = nc0; + } + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + int32_t beta0 = ( pc == 0 ) ? beta : 1; + dim_t kc0 = bli_min( ( k - pc ), KC ); + + // kc0 needs to be a multiple of 4 so that it can be + // used with vpdpbusd instruction. Padding is added in + // cases this condition is not satisfied, and therefore + // the kc0 offsets used for packed/reordered buffers + // needs to be updated. + dim_t kc0_updated = make_multiple_of_n( kc0, 4 ); + + is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE ); + + if ( mtag_b == PACK ) + { + // Pack B chunks are based on jc work id. + dim_t jc_work_id = bli_thread_work_id( &thread_jc ); + + // Using child thrinfo (thread_ic) tid to decide chief thread + // per B matrix chunk (jc work id group) + if ( bli_thread_am_ochief( &thread_ic ) ) + { + // nc0 needs to be a multiple of 16 since this gives maximum + // vectorization. Packing B always results in buffers with width + // which is a multiple of 16. Subsequently the nc0 offsets used + // for packed/reordered buffers needs to be updated. + dim_t nc0_updated = make_multiple_of_n( nc0, packb_min_NR ); + mem_b_size_req = sizeof( int8_t ) * nc0_updated * kc0_updated; + + lpgemm_alloc_mem_panel + ( + mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL, + &mem_b, rntm + ); + + thread->comm[jc_work_id].sent_object = + bli_mem_buffer( &mem_b ); + } + + // All threads in work group should wait till chief thread has + // finished allocating the packing buffers. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id( &thread_ic ), + &thread->comm[jc_work_id] + ); + + pack_b_buffer_u8s8s32o32 = + ( int8_t* ) thread->comm[jc_work_id].sent_object; + + // Compute the B panel per thread loop range for parallel + // packing using ic_ways number of threads. Since atmost only + // ic_ways threads can be used, the thread_ic attributes are + // used to split the loop range. + dim_t jc_packb_start, jc_packb_end; + bli_thread_range_sub + ( + &thread_ic, nc0, NR, FALSE, + &jc_packb_start, &jc_packb_end + ); + + // Ensure thread ranges are valid, especially cases where no: + // of threads available for parallelization are greater than + // no: of B panel NR chunks. 
+ if ( ( jc_packb_end > jc_packb_start ) && + ( jc_packb_start < ( jc + nc0 ) ) ) + { +#ifdef BLIS_KERNELS_ZEN4 + packb_nr64_u8s8s32o32 + ( + pack_b_buffer_u8s8s32o32 + ( jc_packb_start * kc0_updated ), + ( b + ( rs_b * pc ) + ( cs_b * jc ) + + ( cs_b * jc_packb_start ) ), rs_b, + ( jc_packb_end - jc_packb_start ), kc0, + &rs_b_use, &cs_b_use + ); +#endif + } + else + { + get_packb_nr64_u8s8s32o32_strides( &rs_b_use, &cs_b_use ); + } + + // All threads in work group should wait till B matrix packing + // is completed by the participating threads. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id( &thread_ic ), + &thread->comm[jc_work_id] + ); + b_use = pack_b_buffer_u8s8s32o32; + } + else if ( mtag_b == REORDERED ) + { + // In multi-threaded scenarios, an extra offset into a given + // packed B panel is required, since the jc loop split can + // result in per thread start offset inside the panel, instead + // of panel boundaries. + b_use = b + ( jc_cur_loop * k_updated ) + + ( n_sub_updated * pc ) + + ( jc_cur_loop_rem * kc0_updated ); + + get_packb_nr64_u8s8s32o32_strides( &rs_b_use, &cs_b_use ); + } + else + { + //Unpacked B not supported. + return; + } + + for ( dim_t ic = ic_start; ic < ic_end; ic += MC ) + { + dim_t mc0 = bli_min( ( ic_end - ic ), MC ); + + // Only per thread C matrix is stored in temp buffer, so both + // per thread jc and ic start should be normalized to zero. + if ( c_downscale == TRUE ) + { + c_use_ic = c_use_jc + ( rs_c_use * ( ic - ic_start ) ); + } + else + { + c_use_ic = c_use_jc + ( rs_c_use * ic ); + } + + // Matrix A packed and reordered code path is not triggerred + // currently since we do not support it yet. + if ( mtag_a == PACK ) + { + mem_a_size_req = sizeof( uint8_t ) * mc0 * kc0_updated; + + lpgemm_alloc_mem_panel + ( + mem_a_size_req, BLIS_BUFFER_FOR_A_BLOCK, + &mem_a, rntm + ); + pack_a_buffer_u8s8s32o32 = ( uint8_t* )bli_mem_buffer( &mem_a ); + +#ifdef BLIS_KERNELS_ZEN4 + packa_k64_u8s8s32o32 + ( + pack_a_buffer_u8s8s32o32, + ( a + ( rs_a * ic ) + pc ), rs_a, + mc0, kc0, + &rs_a_use, &cs_a_use + ); +#endif + a_use = pack_a_buffer_u8s8s32o32; + a_block_stride = kc0_updated; + } + else if ( mtag_a == REORDERED ) + { + get_packa_k64_u8s8s32o32_strides( &rs_a_use, &cs_a_use ); + a_use = a + ( pc * m ) + ( kc0_updated * ic ); + a_block_stride = kc0_updated; + } + else + { + a_use = a + ( rs_a * ic ) + ( cs_a * pc ); + + // Int8 kernel reads 4 elements, totalling 4 bytes in a + // single broadcast for use in vnni instruction. + // Non vnni based kernel requires update to this code. + cs_a_use = 4; + a_block_stride = rs_a; + } + + for ( dim_t jr = 0; jr < nc0; jr += NR ) + { + dim_t nr0 = bli_min( ( nc0 - jr ), NR ); + +#ifdef BLIS_KERNELS_ZEN4 + // Reorder/Packed B, Reorder/Packed/Unpacked A call. + lpgemm_rowvar_u8s8s32o32_6x64 + ( + mc0, nr0, kc0, + a_use, rs_a_use, cs_a_use, a_block_stride, + ( b_use + ( jr * kc0_updated ) ), rs_b_use, cs_b_use, + ( c_use_ic + jr ), rs_c_use, 1, + alpha, beta0, + is_last_k, ic, ( jc + jr ), post_op_list, rs_c_downscale + ); +#else + // Silence compiler warnings. + ( void )b_use; + ( void )a_block_stride; + ( void )rs_c_downscale; + ( void )is_last_k; + ( void )c_use_ic; + ( void )a_use; + ( void )beta0; + ( void )nr0; +#endif + } + } + } + if ( mtag_b == REORDERED ) + { + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); + } + } + + // Release pack buffers. + if ( mtag_b == PACK ) + { + // All threads in work group should wait till B matrix usage is + // completed by the participating threads. 
+ bli_thrcomm_barrier + ( + bli_thread_ocomm_id( &thread_jc ), + &thread->comm[bli_thread_work_id( &thread_jc)] + ); + + if ( bli_thread_am_ochief( &thread_ic ) ) + { + if ( bli_mem_is_alloc( &mem_b ) ) + { + bli_membrk_release( rntm, &mem_b ); + } + } + } + if ( mtag_a == PACK ) + { + if ( bli_mem_is_alloc( &mem_a ) ) + { + bli_membrk_release( rntm, &mem_a ); + } + } + if ( c_downscale == TRUE ) + { + if ( bli_mem_is_alloc( &mem_scale_c ) ) + { + bli_membrk_release( rntm, &mem_scale_c ); + } + } +} diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.c b/addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.c new file mode 100644 index 0000000000..aa6669469d --- /dev/null +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.c @@ -0,0 +1,156 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_utils.h" + +dim_t get_64byte_aligned_memory + ( + void** original_memory, + void** aligned_memory, + int64_t allocate_size + ) +{ + // Get 64 byte aligned memory. + int8_t* t1_original = ( int8_t* ) malloc( allocate_size + 64 ); + if ( t1_original == NULL ) + { + //Error in malloc. + *original_memory = NULL; + *aligned_memory = NULL; + return -1; + } + + int8_t* ta_original = t1_original + 64; + ta_original = ta_original - ( ( int64_t )( ta_original ) % 64 ); + + *original_memory = t1_original; + *aligned_memory = ta_original; + return 0; +} + +static lpgemm_obj_t* alloc_lpgemm_obj_t_u8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme, + AOCL_MEMORY_TAG mtag + ) +{ + lpgemm_obj_t* obj = ( lpgemm_obj_t* ) malloc( sizeof( lpgemm_obj_t ) ); + + if ( obj == NULL ) + { + return NULL; //failure + } + + // Allocate aligned buffers. + get_64byte_aligned_memory( &obj->storage.origin_buffer, + &obj->storage.aligned_buffer, + ( elem_size * length * width ) ); + + if ( obj->storage.origin_buffer == NULL ) + { + // Buffer allocation failed. 
+ free( obj ); + return NULL; + } + + obj->length = length; + obj->width = width; + obj->elem_size = elem_size; + + if ( stor_scheme == ROW_MAJOR ) + { + obj->rs = stride; + obj->cs = 4; // 4 elements read at a time. + } + else if ( stor_scheme == COLUMN_MAJOR ) + { + obj->cs = stride; + obj->rs = 1; + } + obj->mtag = mtag; + + return obj; +} + +lpgemm_obj_t* alloc_unpack_tag_lpgemm_obj_t_u8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme + ) +{ + return alloc_lpgemm_obj_t_u8s8s32( length, width, stride, elem_size, stor_scheme, UNPACKED ); +} + +lpgemm_obj_t* alloc_pack_tag_lpgemm_obj_t_u8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme + ) +{ + return alloc_lpgemm_obj_t_u8s8s32( length, width, stride, elem_size, stor_scheme, PACK ); +} + +lpgemm_obj_t* alloc_reorder_tag_lpgemm_obj_t_u8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme + ) +{ + // Extra space since packing does width in multiples of 16. + dim_t width_reorder = make_multiple_of_n( width, 16 ); + // Extra space since packing does length in multiples of 4. + dim_t length_reorder = make_multiple_of_n( length, 4 ); + + return alloc_lpgemm_obj_t_u8s8s32( length_reorder, width_reorder, stride, elem_size, stor_scheme, REORDERED ); +} + +void dealloc_lpgemm_obj_t_u8s8s32( lpgemm_obj_t* obj ) +{ + free( obj->storage.origin_buffer ); + free( obj ); +} diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.h b/addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.h new file mode 100644 index 0000000000..93acad6ac9 --- /dev/null +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.h @@ -0,0 +1,225 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_UTILS_H +#define LPGEMM_UTILS_H + +#include "lpgemm_types.h" + +// Users of this API needs to free the allocated memory on their own. 
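+// Illustrative usage (hypothetical caller):
+//
+//   void* orig    = NULL;
+//   void* aligned = NULL;
+//   if ( get_64byte_aligned_memory( &orig, &aligned, 1024 ) == 0 )
+//   {
+//       // ... use the 64-byte aligned buffer via "aligned" ...
+//       free( orig ); // Free the original pointer, not the aligned one.
+//   }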
+dim_t get_64byte_aligned_memory + ( + void** original_memory, + void** aligned_memory, + int64_t allocate_size + ); + +lpgemm_obj_t* alloc_unpack_tag_lpgemm_obj_t_u8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme + ); + +lpgemm_obj_t* alloc_pack_tag_lpgemm_obj_t_u8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme + ); + +lpgemm_obj_t* alloc_reorder_tag_lpgemm_obj_t_u8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme + ); + +void dealloc_lpgemm_obj_t_u8s8s32( lpgemm_obj_t* obj ); + +BLIS_INLINE void bli_param_map_char_to_lpmtag + ( + char mtag, + AOCL_MEMORY_TAG* lp_mtag + ) +{ + if ( mtag == 'n' || mtag == 'N' ) *lp_mtag = UNPACKED; + else if ( mtag == 'p' || mtag == 'P' ) *lp_mtag = PACK; + else if ( mtag == 'r' || mtag == 'R' ) *lp_mtag = REORDERED; + else + { + *lp_mtag = UNPACKED; + } +} + +BLIS_INLINE void bli_param_map_char_to_lpmat_type + ( + const char mtag, + AOCL_MATRIX_TYPE* lp_mat_type + ) +{ + if ( mtag == 'a' || mtag == 'A' ) *lp_mat_type = A_MATRIX; + else if ( mtag == 'b' || mtag == 'B' ) *lp_mat_type = B_MATRIX; + else + { + *lp_mat_type = B_MATRIX; + } +} + +BLIS_INLINE dim_t make_multiple_of_n( dim_t k, dim_t n ) +{ + if ( n <= 0 ) + { + return 0; + } + + return ( ( ( k + n - 1 ) / n ) * n ); +} + +BLIS_INLINE void lpgemm_alloc_mem_panel + ( + dim_t size_req, + packbuf_t buf_type, + mem_t* mem, + rntm_t* rntm_l + ) +{ + if ( bli_mem_is_unalloc( mem ) ) + { + bli_membrk_acquire_m + ( + rntm_l, + size_req, + buf_type, + mem + ); + } + else + { + siz_t mem_size = bli_mem_size( mem ); + if ( mem_size < size_req ) + { + bli_membrk_release( rntm_l, mem ); + bli_membrk_acquire_m + ( + rntm_l, + size_req, + buf_type, + mem + ); + } + } +} + +BLIS_INLINE dim_t get_Bpanel_width_for_kdim_traversal + ( + dim_t jc, + dim_t n, + dim_t NC, + dim_t NR + ) +{ + dim_t n_mod_NR = n % NR; + dim_t n_sub_updated = NC; + + if ( ( n % NC ) != 0 ) + { + // Only applicable to final NC part of jc loop where jc + remaining + // elements is less than NC; or when n < NC in which case panel width + // is atmost n. + dim_t n_last_loop = ( n / NC ) * NC; + if ( jc >= n_last_loop ) + { + n_sub_updated = n - n_last_loop; + if ( n_mod_NR != 0 ) + { + n_sub_updated += ( NR - n_mod_NR ); + } + } + } + + return n_sub_updated; +} + +BLIS_INLINE void get_B_panel_reordered_start_offset_width + ( + dim_t jc, + dim_t n, + dim_t NC, + dim_t NR, + dim_t* panel_start, + dim_t* panel_offset, + dim_t* panel_width, + dim_t* panel_width_kdim_trav + ) +{ + // Since n dimension is split across threads in units of NR blocks, + // it could happen that B matrix chunk for a thread may be part of + // two separate NCxKC panels. In this case nc0 is updated such that + // the jr loop only accesses the remaining portion of current NCxKC + // panel, with the next jc iteration taking care of the other panel. + // This ensures that jr loop does not cross panel boundaries. + ( *panel_start ) = ( jc / NC ) * NC; + ( *panel_offset ) = jc - ( *panel_start ); + + // Check if jc + current_panel_width (nc0) crosses panel boundaries. 
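+	// E.g. (hypothetical sizes) with NC = 128, NR = 64 and a thread whose
+	// chunk starts at jc = 64 with an incoming panel_width of 128:
+	// panel_start = 0, panel_offset = 64, and since 64 + 128 > 0 + 128
+	// the width is clamped to NC - panel_offset = 64, keeping the jr
+	// loop inside the current NCxKC panel.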
+ if ( ( jc + ( *panel_width ) ) > ( ( *panel_start ) + NC ) ) + { + ( *panel_width ) = NC - ( *panel_offset ); + } + + ( *panel_width_kdim_trav ) = get_Bpanel_width_for_kdim_traversal + ( + jc, n, NC, NR + ); +} + +BLIS_INLINE void adjust_B_panel_reordered_jc( dim_t* jc, dim_t panel_start ) +{ + // Since n dimension is split across threads in units of NR blocks, + // it could happen that B matrix chunk for a thread may be part of + // two separate NCxKC panels. In this case jc is reset to immediate + // previous panel offset so that in the next iteration, the + // following panel belonging to the B chunk is accessed. This + // ensures that jr loop does not cross panel boundaries. + ( *jc ) = panel_start; +} + +#endif //LPGEMM_UTILS_H diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c new file mode 100644 index 0000000000..65a4963dcb --- /dev/null +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c @@ -0,0 +1,1146 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS dim_tERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include + +#include "blis.h" +#include "lpgemm_kernels.h" +#include "lpgemm_f32_kern_macros.h" + +#ifdef BLIS_KERNELS_ZEN4 +// 6x64 bf16 kernel +LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6x64_DISABLE, + &&POST_OPS_BIAS_6x64, + &&POST_OPS_RELU_6x64, + &&POST_OPS_RELU_SCALE_6x64, + &&POST_OPS_DOWNSCALE_6x64 + }; + dim_t MR = 6; + dim_t NR = 64; + + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + if ( n0 < NR ) + { + dim_t n0_rem = n0 % 16; + + // Split dim_to multiple smaller fringe kernels, so as to maximize + // vectorization. 
Any n0 < NR(64) can be expressed as n0 = 48 + n` + // or n0 = 32 + n` or n0 = 16 + n`, where n` < 16. + dim_t n0_48 = n0 / 48; + dim_t n0_32 = n0 / 32; + dim_t n0_16 = n0 / 16; + + // KC when not multiple of 2 will have padding to make it multiple of + // 2 in packed buffer. Also the k0 cannot be passed as the updated + // value since A matrix is not packed and requires original k0. + dim_t k0_updated = k0; + k0_updated += (k0_updated & 0x1); + + if ( n0_48 == 1 ) + { + lpgemm_rowvar_bf16bf16f32of32_6x48 + ( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ( ( rs_b / 4 ) * 3 ), cs_b, + c, rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + + b = b + ( 48 * k0_updated ); // k0x48 packed contiguosly. + c = c + 48; + post_op_c_j += 48; + } + + else if ( n0_32 == 1 ) + { + lpgemm_rowvar_bf16bf16f32of32_6x32 + ( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ( ( rs_b / 4 ) * 2 ), cs_b, + c, rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + + b = b + ( 32 * k0_updated ); // k0x32 packed contiguosly. + c = c + 32; + post_op_c_j += 32; + } + + else if ( n0_16 == 1 ) + { + lpgemm_rowvar_bf16bf16f32of32_6x16 + ( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ( ( rs_b / 4 ) * 1 ), cs_b, + c, rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + + b = b + ( 16 * k0_updated ); // k0x16 packed contiguosly. + c = c + 16; + post_op_c_j += 16; + } + + if ( n0_rem > 0 ) + { + lpgemm_rowvar_bf16bf16f32of32_6xlt16 + ( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ( ( rs_b / 4 ) * 1 ), cs_b, + c, rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + + // No leftover fringe after this podint. + } + return; + } + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + __m512bh b3; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + __m512bh a_bf16_1; + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + __m512 c_float_1p3 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + __m512 c_float_2p3 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + __m512 c_float_3p2 = _mm512_setzero_ps(); + __m512 c_float_3p3 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + __m512 c_float_4p1 = _mm512_setzero_ps(); + __m512 c_float_4p2 = _mm512_setzero_ps(); + __m512 c_float_4p3 = _mm512_setzero_ps(); + + __m512 c_float_5p0 = _mm512_setzero_ps(); + __m512 c_float_5p1 = _mm512_setzero_ps(); + __m512 c_float_5p2 = _mm512_setzero_ps(); + __m512 c_float_5p3 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // The instructions are arranged in a mixed way to reduce data + // chain dependencies. 
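+			// E.g. the broadcast of the next A row is issued in between the
+			// dpbf16 updates of the previous row, so broadcast latency
+			// overlaps with independent FMA chains instead of stalling them.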
+ + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2] + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )(a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); + c_float_3p3 = _mm512_dpbf16_ps( c_float_3p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-63] = a[4,kr:kr+2]*b[kr:kr+2,0-63] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + + // Broadcast a[5,kr:kr+2]. + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + c_float_4p3 = _mm512_dpbf16_ps( c_float_4p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-63] = a[5,kr:kr+2]*b[kr:kr+2,0-63] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_1, b0 ); + c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_1, b1 ); + c_float_5p2 = _mm512_dpbf16_ps( c_float_5p2, a_bf16_1, b2 ); + c_float_5p3 = _mm512_dpbf16_ps( c_float_5p3, a_bf16_1, b3 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. 
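+			// a_kfringe_buf was zero-initialized, so copying only the
+			// remaining k_partial_pieces bf16 element(s) leaves the unused
+			// bytes of the 32-bit broadcast value zeroed and avoids reading
+			// past the end of the A row.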
+ memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); + + // Broadcast a[4,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); + c_float_3p3 = _mm512_dpbf16_ps( c_float_3p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-63] = a[4,kr:kr+2]*b[kr:kr+2,0-63] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + + // Broadcast a[5,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + c_float_4p3 = _mm512_dpbf16_ps( c_float_4p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. 
+ // c[5,0-63] = a[5,kr:kr+2]*b[kr:kr+2,0-63] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_1, b0 ); + c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_1, b1 ); + c_float_5p2 = _mm512_dpbf16_ps( c_float_5p2, a_bf16_1, b2 ); + c_float_5p3 = _mm512_dpbf16_ps( c_float_5p3, a_bf16_1, b3 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps ( alpha ); + __m512 selector2 = _mm512_set1_ps ( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + c_float_0p3 = _mm512_mul_ps( selector1, c_float_0p3 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + c_float_1p3 = _mm512_mul_ps( selector1, c_float_1p3 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + c_float_2p3 = _mm512_mul_ps( selector1, c_float_2p3 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + c_float_3p2 = _mm512_mul_ps( selector1, c_float_3p2 ); + c_float_3p3 = _mm512_mul_ps( selector1, c_float_3p3 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); + c_float_4p2 = _mm512_mul_ps( selector1, c_float_4p2 ); + c_float_4p3 = _mm512_mul_ps( selector1, c_float_4p3 ); + + c_float_5p0 = _mm512_mul_ps( selector1, c_float_5p0 ); + c_float_5p1 = _mm512_mul_ps( selector1, c_float_5p1 ); + c_float_5p2 = _mm512_mul_ps( selector1, c_float_5p2 ); + c_float_5p3 = _mm512_mul_ps( selector1, c_float_5p3 ); + + // Scale C by beta. 
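+		// When beta == 0 the loads of C are skipped entirely, so alpha*A*B
+		// is stored directly and a possibly uninitialized C buffer is never
+		// read.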
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p3 = _mm512_add_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p2 = _mm512_add_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p3 = _mm512_add_ps( selector1, c_float_2p3 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p2 = _mm512_add_ps( selector1, c_float_3p2 ); + + // c[3,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p3 = _mm512_add_ps( selector1, c_float_3p3 ); + + // c[4,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + 
c_float_4p1 = _mm512_add_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p2 = _mm512_add_ps( selector1, c_float_4p2 ); + + // c[4,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p3 = _mm512_add_ps( selector1, c_float_4p3 ); + + // c[5,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + + // c[5,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p1 = _mm512_add_ps( selector1, c_float_5p1 ); + + // c[5,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p2 = _mm512_add_ps( selector1, c_float_5p2 ); + + // c[5,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p3 = _mm512_add_ps( selector1, c_float_5p3 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x64: + { + __m512 selector3; + __m512 selector4; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_add_ps( selector4, c_float_4p3 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + + // c[5, 16-31] + 
c_float_5p1 = _mm512_add_ps( selector2, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_add_ps( selector3, c_float_5p2 ); + + // c[5,48-63] + c_float_5p3 = _mm512_add_ps( selector4, c_float_5p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + __m512 selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 4 ) ); + __m512 selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 5 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector3, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector5, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_add_ps( selector5, c_float_4p3 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); + + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( selector6, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_add_ps( selector6, c_float_5p2 ); + + // c[5,48-63] + c_float_5p3 = _mm512_add_ps( selector6, c_float_5p3 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x64: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 
= _mm512_max_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_max_ps( selector1, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_max_ps( selector1, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_max_ps( selector1, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_max_ps( selector1, c_float_4p3 ); + + // c[5,0-15] + c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); + + // c[5,16-31] + c_float_5p1 = _mm512_max_ps( selector1, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_max_ps( selector1, c_float_5p2 ); + + // c[5,48-63] + c_float_5p3 = _mm512_max_ps( selector1, c_float_5p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x64: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_2p3) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + // c[3, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_3p3) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_4p2) + + // c[4, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_4p3) + + // c[5, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_5p0) + + // c[5, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_5p1) + + // c[5, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_5p2) + + // c[5, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_5p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_6x64: +{ + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[0, 48-63] + CVT_F32_BF16(c_float_0p3,0,3); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_F32_BF16(c_float_1p2,1,2); + + // c[1, 48-63] + CVT_F32_BF16(c_float_1p3,1,3); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // 
c[2, 32-47] + CVT_F32_BF16(c_float_2p2,2,2); + + // c[2, 48-63] + CVT_F32_BF16(c_float_2p3,2,3); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_F32_BF16(c_float_3p1,3,1); + + // c[3, 32-47] + CVT_F32_BF16(c_float_3p2,3,2); + + // c[3, 48-63] + CVT_F32_BF16(c_float_3p3,3,3); + + // c[4, 0-15] + CVT_F32_BF16(c_float_4p0,4,0); + + // c[4, 16-31] + CVT_F32_BF16(c_float_4p1,4,1); + + // c[4, 32-47] + CVT_F32_BF16(c_float_4p2,4,2); + + // c[4, 48-63] + CVT_F32_BF16(c_float_4p3,4,3); + + // c[5, 0-15] + CVT_F32_BF16(c_float_5p0,5,0); + + // c[5, 16-31] + CVT_F32_BF16(c_float_5p1,5,1); + + // c[5, 32-47] + CVT_F32_BF16(c_float_5p2,5,2); + + // c[5, 48-63] + CVT_F32_BF16(c_float_5p3,5,3); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR +} + +POST_OPS_6x64_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ), c_float_0p2 ); + + // c[0,48-63] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 3*16 ), c_float_0p3 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ), c_float_1p2 ); + + // c[1,48-63] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 3*16 ), c_float_1p3 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ), c_float_2p2 ); + + // c[2,48-63] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 3*16 ), c_float_2p3 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ), c_float_3p1 ); + + // c[3,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ), c_float_3p2 ); + + // c[3,48-63] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 3*16 ), c_float_3p3 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_float_4p0 ); + + // c[4,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ), c_float_4p1 ); + + // c[4,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ), c_float_4p2 ); + + // c[4,48-63] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 3*16 ), c_float_4p3 ); + + // c[5,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_float_5p0 ); + + // c[5,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_float_5p1 ); + + // c[5,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ), c_float_5p2 ); + + // c[5,48-63] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 3*16 ), c_float_5p3 ); + + a = a + ( MR * ps_a ); + post_op_c_i += MR; + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + // In cases where A matrix is packed cs_a is set to 12, since the + // next column in a given row is accessed after 2*6 elements, where + // 6 is MR and 2 elements are broadcasted each time from A (bf16). + // In fringe case, where m < MR, the next column will be after m'*2 + // elements, and subsequently following adjustment of cs_a is + // required before calling m fringe kernels. + dim_t cs_a_use = ( cs_a == 2 ) ? 
2 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_bf16bf16f32of32_5x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 4 ) + { + dim_t cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_bf16bf16f32of32_4x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 3 ) + { + dim_t cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_bf16bf16f32of32_3x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 2 ) + { + dim_t cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_bf16bf16f32of32_2x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 1 ) + { + dim_t cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_bf16bf16f32of32_1x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + } +} +#endif diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_f32_kern_macros.h b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_f32_kern_macros.h new file mode 100644 index 0000000000..c8c2a04c91 --- /dev/null +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_f32_kern_macros.h @@ -0,0 +1,66 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ +#include "aocl_bf16_type.h" + +#ifndef LPGEMM_F32_KERN_MACROS_H +#define LPGEMM_F32_KERN_MACROS_H + +#define RELU_SCALE_OP_F32_AVX512(reg) \ + /* Generate indenx of elements <= 0.*/ \ + relu_cmp_mask = _mm512_cmple_ps_mask( reg, selector1 ); \ + \ + /* Apply scaling on for <= 0 elements.*/ \ + reg = _mm512_mask_mul_ps( reg, relu_cmp_mask, reg, selector2 ); \ + +#define CVT_F32_BF16(reg,m_ind,n_ind) \ + _mm256_storeu_epi16 \ + ( \ + ( bfloat16* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + m_ind ) ) + post_op_c_j + ( n_ind * 16 ), \ + (__m256i) \ + _mm512_cvtneps_pbh( reg ) \ + ) \ + +#define CVT_F32_BF16_LT16(reg,m_ind,n_ind) \ + _mm256_storeu_epi16 \ + ( \ + buf0, \ + (__m256i) \ + _mm512_cvtneps_pbh( reg ) \ + ); \ + memcpy( ( bfloat16* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + m_ind ) ) + post_op_c_j + \ + ( n_ind * 16 ) , buf0, ( n0_rem * sizeof( bfloat16 ) ) ); \ + +#endif // LPGEMM_F32_KERN_MACROS_H \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c new file mode 100644 index 0000000000..e4418b2a0e --- /dev/null +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c @@ -0,0 +1,2592 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+*/
+
+#include <immintrin.h>
+#include <string.h>
+
+#include "blis.h"
+#include "lpgemm_kernels.h"
+#include "lpgemm_f32_kern_macros.h"
+
+#ifdef BLIS_KERNELS_ZEN4
+// 5x64 bf16 kernel
+LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64)
+{
+	static void* post_ops_labels[] =
+			{
+			  &&POST_OPS_5x64_DISABLE,
+			  &&POST_OPS_BIAS_5x64,
+			  &&POST_OPS_RELU_5x64,
+			  &&POST_OPS_RELU_SCALE_5x64,
+			  &&POST_OPS_DOWNSCALE_5x64
+			};
+	dim_t k_full_pieces = k0 / 2;
+	dim_t k_partial_pieces = k0 % 2;
+
+	int32_t a_kfringe_buf = 0;
+
+	// B matrix storage bfloat type
+	__m512bh b0;
+	__m512bh b1;
+	__m512bh b2;
+	__m512bh b3;
+
+	// A matrix storage bfloat type
+	__m512bh a_bf16_0;
+	__m512bh a_bf16_1;
+
+	// Registers to use for accumulating C.
+	__m512 c_float_0p0 = _mm512_setzero_ps();
+	__m512 c_float_0p1 = _mm512_setzero_ps();
+	__m512 c_float_0p2 = _mm512_setzero_ps();
+	__m512 c_float_0p3 = _mm512_setzero_ps();
+
+	__m512 c_float_1p0 = _mm512_setzero_ps();
+	__m512 c_float_1p1 = _mm512_setzero_ps();
+	__m512 c_float_1p2 = _mm512_setzero_ps();
+	__m512 c_float_1p3 = _mm512_setzero_ps();
+
+	__m512 c_float_2p0 = _mm512_setzero_ps();
+	__m512 c_float_2p1 = _mm512_setzero_ps();
+	__m512 c_float_2p2 = _mm512_setzero_ps();
+	__m512 c_float_2p3 = _mm512_setzero_ps();
+
+	__m512 c_float_3p0 = _mm512_setzero_ps();
+	__m512 c_float_3p1 = _mm512_setzero_ps();
+	__m512 c_float_3p2 = _mm512_setzero_ps();
+	__m512 c_float_3p3 = _mm512_setzero_ps();
+
+	__m512 c_float_4p0 = _mm512_setzero_ps();
+	__m512 c_float_4p1 = _mm512_setzero_ps();
+	__m512 c_float_4p2 = _mm512_setzero_ps();
+	__m512 c_float_4p3 = _mm512_setzero_ps();
+
+	for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 )
+	{
+		b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) );
+
+		// Broadcast a[0,kr:kr+2].
+		a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) );
+
+		b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) );
+		b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) );
+		b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) );
+
+		// Perform column direction mat-mul with k = 2.
+		// c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63]
+		c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 );
+
+		// Broadcast a[1,kr:kr+2].
+		a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) );
+
+		c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 );
+		c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 );
+		c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 );
+
+		// Perform column direction mat-mul with k = 2.
+		// c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63]
+		c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 );
+
+		// Broadcast a[2,kr:kr+2].
+		a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) );
+
+		c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 );
+		c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 );
+		c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 );
+
+		// Perform column direction mat-mul with k = 2.
+		// c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63]
+		c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 );
+
+		// Broadcast a[3,kr:kr+2].
+ a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); + c_float_3p3 = _mm512_dpbf16_ps( c_float_3p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-63] = a[4,kr:kr+2]*b[kr:kr+2,0-63] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + c_float_4p3 = _mm512_dpbf16_ps( c_float_4p3, a_bf16_0, b3 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. 
+ // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); + + // Broadcast a[4,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); + c_float_3p3 = _mm512_dpbf16_ps( c_float_3p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-63] = a[4,kr:kr+2]*b[kr:kr+2,0-63] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + c_float_4p3 = _mm512_dpbf16_ps( c_float_4p3, a_bf16_0, b3 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + c_float_0p3 = _mm512_mul_ps( selector1, c_float_0p3 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + c_float_1p3 = _mm512_mul_ps( selector1, c_float_1p3 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + c_float_2p3 = _mm512_mul_ps( selector1, c_float_2p3 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + c_float_3p2 = _mm512_mul_ps( selector1, c_float_3p2 ); + c_float_3p3 = _mm512_mul_ps( selector1, c_float_3p3 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); + c_float_4p2 = _mm512_mul_ps( selector1, c_float_4p2 ); + c_float_4p3 = _mm512_mul_ps( selector1, c_float_4p3 ); + + // Scale C by beta. 
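Each _mm512_dpbf16_ps step in the loops above combines a broadcast pair of bf16 values from A with a pair of packed B rows and accumulates into 16 float lanes; per output element that is a two-term dot product in float. A scalar sketch of one such k = 2 update, with bf16 promoted to float first (instruction-level details such as denormal handling are ignored here):

#include <stdint.h>
#include <string.h>

/* Promote one bf16 value (stored as uint16_t here) to float by placing
   its 16 bits into the high half of an IEEE-754 single. */
static float bf16_to_f32( uint16_t x )
{
	uint32_t bits = ( uint32_t )x << 16;
	float f;
	memcpy( &f, &bits, sizeof( f ) );
	return f;
}

/* Scalar view of one k = 2 step for a single output element:
   acc += a[kr]*b[kr] + a[kr+1]*b[kr+1]. */
static float dpbf16_pair_ref( float acc, const uint16_t a_pair[ 2 ],
                              const uint16_t b_pair[ 2 ] )
{
	acc += bf16_to_f32( a_pair[ 0 ] ) * bf16_to_f32( b_pair[ 0 ] );
	acc += bf16_to_f32( a_pair[ 1 ] ) * bf16_to_f32( b_pair[ 1 ] );
	return acc;
}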
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p3 = _mm512_add_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p2 = _mm512_add_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p3 = _mm512_add_ps( selector1, c_float_2p3 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p2 = _mm512_add_ps( selector1, c_float_3p2 ); + + // c[3,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p3 = _mm512_add_ps( selector1, c_float_3p3 ); + + // c[4,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 4 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 4 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p1 = _mm512_add_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 4 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( 
selector2, selector1 ); + c_float_4p2 = _mm512_add_ps( selector1, c_float_4p2 ); + + // c[4,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 4 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p3 = _mm512_add_ps( selector1, c_float_4p3 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5x64: + { + __m512 selector3; + __m512 selector4; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_add_ps( selector4, c_float_4p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. 
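In scalar terms, the two bias branches differ only in how the bias vector is indexed: one entry per output column (offset post_op_c_j) when the output is row major, or one entry per row of the tile (offset post_op_c_i) when a column-major output is seen by the kernel as transposed. A reference sketch under that reading (the helper name and the int64_t stand-in for dim_t are illustrative, not part of the library):

#include <stdint.h>

/* Add the bias to an m x n tile of C, mirroring the two branches above:
   bias[c_j + j] per column for a row-major output, bias[c_i + i] per row
   otherwise. rs_c is the row stride of C. */
static void bias_post_op_ref( float* c, int64_t rs_c, int64_t m, int64_t n,
                              const float* bias, int64_t c_i, int64_t c_j,
                              int row_major_output )
{
	for ( int64_t i = 0; i < m; ++i )
	{
		for ( int64_t j = 0; j < n; ++j )
		{
			float b = row_major_output ? bias[ c_j + j ] : bias[ c_i + i ];
			c[ ( i * rs_c ) + j ] += b;
		}
	}
}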
+ selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + __m512 selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 4 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector3, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector5, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_add_ps( selector5, c_float_4p3 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x64: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_max_ps( selector1, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_max_ps( selector1, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_max_ps( selector1, c_float_4p2 ); + + // 
c[4,48-63] + c_float_4p3 = _mm512_max_ps( selector1, c_float_4p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x64: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_2p3) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + // c[3, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_3p3) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_4p2) + + // c[4, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_4p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_5x64: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[0, 48-63] + CVT_F32_BF16(c_float_0p3,0,3); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_F32_BF16(c_float_1p2,1,2); + + // c[1, 48-63] + CVT_F32_BF16(c_float_1p3,1,3); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_F32_BF16(c_float_2p2,2,2); + + // c[2, 48-63] + CVT_F32_BF16(c_float_2p3,2,3); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_F32_BF16(c_float_3p1,3,1); + + // c[3, 32-47] + CVT_F32_BF16(c_float_3p2,3,2); + + // c[3, 48-63] + CVT_F32_BF16(c_float_3p3,3,3); + + // c[4, 0-15] + CVT_F32_BF16(c_float_4p0,4,0); + + // c[4, 16-31] + CVT_F32_BF16(c_float_4p1,4,1); + + // c[4, 32-47] + CVT_F32_BF16(c_float_4p2,4,2); + + // c[4, 48-63] + CVT_F32_BF16(c_float_4p3,4,3); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x64_DISABLE: + ; + + // Store the results. 
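The downscale post-op above narrows each 16-float block to bfloat16 with _mm512_cvtneps_pbh (convert with round to nearest even) and writes it to the buffer in op_args3 via the CVT_F32_BF16 macro. For context, a scalar reference for that rounding (a sketch; NaN inputs are not special-cased here):

#include <stdint.h>
#include <string.h>

/* Round-to-nearest-even float -> bf16 narrowing: keep the top 16 bits of
   the single-precision encoding, rounding on the 16 discarded bits. */
static uint16_t f32_to_bf16_rne( float x )
{
	uint32_t bits;
	memcpy( &bits, &x, sizeof( bits ) );
	uint32_t lsb = ( bits >> 16 ) & 1u;   /* last bit that will be kept */
	bits += 0x7FFFu + lsb;                /* ties round to even */
	return ( uint16_t )( bits >> 16 );
}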
+ // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[0,48-63] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 3*16 ), c_float_0p3 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + + // c[1,48-63] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 3*16 ), c_float_1p3 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 2*16 ), c_float_2p2 ); + + // c[2,48-63] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 3*16 ), c_float_2p3 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 1*16 ), c_float_3p1 ); + + // c[3,32-47] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 2*16 ), c_float_3p2 ); + + // c[3,48-63] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 3*16 ), c_float_3p3 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 0*16 ), c_float_4p0 ); + + // c[4,16-31] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 1*16 ), c_float_4p1 ); + + // c[4,32-47] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 2*16 ), c_float_4p2 ); + + // c[4,48-63] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 3*16 ), c_float_4p3 ); +} + +// 4x64 bf16 kernel +LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_4x64_DISABLE, + &&POST_OPS_BIAS_4x64, + &&POST_OPS_RELU_4x64, + &&POST_OPS_RELU_SCALE_4x64, + &&POST_OPS_DOWNSCALE_4x64 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + __m512bh b3; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + __m512bh a_bf16_1; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + __m512 c_float_1p3 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + __m512 c_float_2p3 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + __m512 c_float_3p2 = _mm512_setzero_ps(); + __m512 c_float_3p3 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. 
+ // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); + c_float_3p3 = _mm512_dpbf16_ps( c_float_3p3, a_bf16_1, b3 ); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. 
+ // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); + c_float_3p3 = _mm512_dpbf16_ps( c_float_3p3, a_bf16_1, b3 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + c_float_0p3 = _mm512_mul_ps( selector1, c_float_0p3 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + c_float_1p3 = _mm512_mul_ps( selector1, c_float_1p3 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + c_float_2p3 = _mm512_mul_ps( selector1, c_float_2p3 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + c_float_3p2 = _mm512_mul_ps( selector1, c_float_3p2 ); + c_float_3p3 = _mm512_mul_ps( selector1, c_float_3p3 ); + + // Scale C by beta. 
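The alpha scaling just applied and the beta block that follows implement the usual C := alpha*acc + beta*C update, with the load of C skipped entirely when beta == 0. For reference, a scalar sketch of the same epilogue (acc stands for the accumulator registers; int64_t stands in for dim_t):

#include <stdint.h>

/* Scalar form of the epilogue applied to an m x n tile. */
static void scale_c_ref( float* c, int64_t rs_c,
                         const float* acc, int64_t rs_acc,
                         int64_t m, int64_t n, float alpha, float beta )
{
	for ( int64_t i = 0; i < m; ++i )
	{
		for ( int64_t j = 0; j < n; ++j )
		{
			float v = alpha * acc[ ( i * rs_acc ) + j ];
			if ( beta != 0.0f )
			{
				v += beta * c[ ( i * rs_c ) + j ];  /* C is read only when needed */
			}
			c[ ( i * rs_c ) + j ] = v;
		}
	}
}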
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p3 = _mm512_add_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p2 = _mm512_add_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p3 = _mm512_add_ps( selector1, c_float_2p3 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p2 = _mm512_add_ps( selector1, c_float_3p2 ); + + // c[3,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p3 = _mm512_add_ps( selector1, c_float_3p3 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x64: + { + __m512 selector3; + __m512 selector4; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + selector3 
= + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector3, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x64: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + 
c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_max_ps( selector1, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_max_ps( selector1, c_float_3p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x64: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_2p3) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + // c[3, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_3p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x64: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[0, 48-63] + CVT_F32_BF16(c_float_0p3,0,3); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_F32_BF16(c_float_1p2,1,2); + + // c[1, 48-63] + CVT_F32_BF16(c_float_1p3,1,3); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_F32_BF16(c_float_2p2,2,2); + + // c[2, 48-63] + CVT_F32_BF16(c_float_2p3,2,3); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_F32_BF16(c_float_3p1,3,1); + + // c[3, 32-47] + CVT_F32_BF16(c_float_3p2,3,2); + + // c[3, 48-63] + CVT_F32_BF16(c_float_3p3,3,3); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + +POST_OPS_4x64_DISABLE: + ; + + // Store the results. 
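Element-wise, the RELU and RELU_SCALE post-ops above reduce to max(x, 0) and a leaky variant that multiplies only the non-positive elements by the scale held in op_args2. A scalar sketch of both:

/* Plain ReLU reference: max(x, 0). */
static inline float relu_ref( float x )
{
	return ( x > 0.0f ) ? x : 0.0f;
}

/* Leaky/parametric variant matching RELU_SCALE_OP_F32_AVX512: elements
   <= 0 are scaled, positive elements pass through unchanged. */
static inline float relu_scale_ref( float x, float scale )
{
	return ( x > 0.0f ) ? x : ( x * scale );
}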
+ // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[0,48-63] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 3*16 ), c_float_0p3 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + + // c[1,48-63] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 3*16 ), c_float_1p3 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 2*16 ), c_float_2p2 ); + + // c[2,48-63] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 3*16 ), c_float_2p3 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 1*16 ), c_float_3p1 ); + + // c[3,32-47] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 2*16 ), c_float_3p2 ); + + // c[3,48-63] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 3*16 ), c_float_3p3 ); +} + +// 3x64 bf16 kernel +LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_3x64_DISABLE, + &&POST_OPS_BIAS_3x64, + &&POST_OPS_RELU_3x64, + &&POST_OPS_RELU_SCALE_3x64, + &&POST_OPS_DOWNSCALE_3x64 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + __m512bh b3; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + __m512bh a_bf16_1; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + __m512 c_float_1p3 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + __m512 c_float_2p3 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. 
+ // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. 
+ // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + c_float_0p3 = _mm512_mul_ps( selector1, c_float_0p3 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + c_float_1p3 = _mm512_mul_ps( selector1, c_float_1p3 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + c_float_2p3 = _mm512_mul_ps( selector1, c_float_2p3 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p3 = _mm512_add_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p2 = _mm512_add_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p3 = _mm512_add_ps( selector1, c_float_2p3 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + 
POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3x64: + { + __m512 selector3; + __m512 selector4; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. 
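+ // Hence one bias value is broadcast per output row here (selector1,
+ // selector2, selector3 for rows 0, 1, 2), rather than loading a
+ // 64-wide bias vector indexed by post_op_c_j as in the row-major
+ // branch above.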
+ selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector3, c_float_2p3 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x64: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_max_ps( selector1, c_float_2p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x64: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_2p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_3x64: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[0, 48-63] + CVT_F32_BF16(c_float_0p3,0,3); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[1, 32-47] + 
CVT_F32_BF16(c_float_1p2,1,2); + + // c[1, 48-63] + CVT_F32_BF16(c_float_1p3,1,3); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_F32_BF16(c_float_2p2,2,2); + + // c[2, 48-63] + CVT_F32_BF16(c_float_2p3,2,3); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x64_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[0,48-63] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 3*16 ), c_float_0p3 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + + // c[1,48-63] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 3*16 ), c_float_1p3 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 2*16 ), c_float_2p2 ); + + // c[2,48-63] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 3*16 ), c_float_2p3 ); +} + +// 2x64 bf16 kernel +LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_2x64_DISABLE, + &&POST_OPS_BIAS_2x64, + &&POST_OPS_RELU_2x64, + &&POST_OPS_RELU_SCALE_2x64, + &&POST_OPS_DOWNSCALE_2x64 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + __m512bh b3; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + __m512bh a_bf16_1; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + __m512 c_float_1p3 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. 
+ // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + } + + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + c_float_0p3 = _mm512_mul_ps( selector1, c_float_0p3 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + c_float_1p3 = _mm512_mul_ps( selector1, c_float_1p3 ); + + // Scale C by beta. 
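+ // The final result is C := alpha*(A*B) + beta*C, accumulated in f32.
+ // When beta == 0 the existing contents of C are not read at all, so an
+ // uninitialized C buffer is safe in that case.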
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p3 = _mm512_add_ps( selector1, c_float_1p3 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x64: + { + __m512 selector3; + __m512 selector4; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. 
+ selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x64: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x64: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_1p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_DOWNSCALE_2x64: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[0, 48-63] + CVT_F32_BF16(c_float_0p3,0,3); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_F32_BF16(c_float_1p2,1,2); + + // c[1, 48-63] + CVT_F32_BF16(c_float_1p3,1,3); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x64_DISABLE: + ; + + // Store the results. 
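+ // Each row of the 2x64 tile is written back as four unaligned 16-float
+ // stores at column offsets 0, 16, 32 and 48, stepping rows by rs_c.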
+ // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[0,48-63] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 3*16 ), c_float_0p3 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + + // c[1,48-63] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 3*16 ), c_float_1p3 ); +} + +// 1x64 bf16 kernel +LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_1x64_DISABLE, + &&POST_OPS_BIAS_1x64, + &&POST_OPS_RELU_1x64, + &&POST_OPS_RELU_SCALE_1x64, + &&POST_OPS_DOWNSCALE_1x64 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512bh b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr] + __m512bh a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + __m512bh b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + __m512bh b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + __m512bh b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512bh b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + __m512bh a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + __m512bh b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + __m512bh b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + __m512bh b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + c_float_0p3 = _mm512_mul_ps( selector1, c_float_0p3 ); + + // Scale C by beta. 
+ if ( beta != 0) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x64: + { + __m512 selector3; + __m512 selector4; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. 
+ selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x64: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x64: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1x64: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[0, 48-63] + CVT_F32_BF16(c_float_0p3,0,3); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x64_DISABLE: + ; + + // Store the accumulated results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[0,48-63] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 3*16 ), c_float_0p3 ); +} +#endif diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c new file mode 100644 index 0000000000..6e985f154f --- /dev/null +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c @@ -0,0 +1,5843 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include <immintrin.h>
+#include <string.h>
+
+#include "blis.h"
+#include "lpgemm_kernels.h"
+#include "lpgemm_f32_kern_macros.h"
+
+#ifdef BLIS_KERNELS_ZEN4
+// 5xlt16 bf16 fringe kernel
+LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16)
+{
+ static void* post_ops_labels[] =
+ {
+ &&POST_OPS_5xLT16_DISABLE,
+ &&POST_OPS_BIAS_5xLT16,
+ &&POST_OPS_RELU_5xLT16,
+ &&POST_OPS_RELU_SCALE_5xLT16,
+ &&POST_OPS_DOWNSCALE_5xLT16
+ };
+ dim_t k_full_pieces = k0 / 2;
+ dim_t k_partial_pieces = k0 % 2;
+
+ int32_t a_kfringe_buf = 0;
+
+ // B matrix storage bfloat type
+ __m512bh b0;
+
+ // A matrix storage bfloat type
+ __m512bh a_bf16_0;
+
+ // For corner cases.
+ float buf0[16];
+ float buf1[16];
+ float buf2[16];
+ float buf3[16];
+ float buf4[16];
+
+ // Registers to use for accumulating C.
+ __m512 c_float_0p0 = _mm512_setzero_ps();
+
+ __m512 c_float_1p0 = _mm512_setzero_ps();
+
+ __m512 c_float_2p0 = _mm512_setzero_ps();
+
+ __m512 c_float_3p0 = _mm512_setzero_ps();
+
+ __m512 c_float_4p0 = _mm512_setzero_ps();
+
+ for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 )
+ {
+ b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) );
+
+ // Broadcast a[0,kr:kr+2].
+ a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15]
+ c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 );
+
+ // Broadcast a[1,kr:kr+2].
+ a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15]
+ c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 );
+
+ // Broadcast a[2,kr:kr+2].
+ a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15]
+ c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 );
+
+ // Broadcast a[3,kr:kr+2].
+ a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15]
+ c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 );
+
+ // Broadcast a[4,kr:kr+2].
+ a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15]
+ c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 );
+ }
+ // Handle k remainder.
+ if ( k_partial_pieces > 0 )
+ {
+ b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) );
+
+ // Broadcast a[0,kr:kr+2].
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) );
+ a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + + // Scale C by beta. 
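+ // Only n0_rem (< 16) columns remain in this fringe case, so C rows are
+ // staged through the 16-float buf* arrays with memcpy; the 512-bit
+ // loads/stores then never touch memory beyond the valid n0_rem
+ // elements of C.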
+ if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf1, ( c + ( rs_c * 1 ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf2, ( c + ( rs_c * 2 ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf3, ( c + ( rs_c * 3 ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf4, ( c + ( rs_c * 4 ) ), ( n0_rem * sizeof( float ) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_ps( buf0 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( buf1 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( buf2 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( buf3 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + selector1 = _mm512_loadu_ps( buf4 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5xLT16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_ps( buf0 ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + __m512 selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 4 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5xLT16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5xLT16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) 
); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_5xLT16: + { + // c[0, 0-15] + CVT_F32_BF16_LT16(c_float_0p0,0,0); + + // c[1, 0-15] + CVT_F32_BF16_LT16(c_float_1p0,1,0); + + // c[2, 0-15] + CVT_F32_BF16_LT16(c_float_2p0,2,0); + + // c[3, 0-15] + CVT_F32_BF16_LT16(c_float_3p0,3,0); + + // c[4, 0-15] + CVT_F32_BF16_LT16(c_float_4p0,4,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5xLT16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( buf0, c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( buf1, c_float_1p0 ); + + // c[2,0-15] + _mm512_storeu_ps( buf2, c_float_2p0 ); + + // c[3,0-15] + _mm512_storeu_ps( buf3, c_float_3p0 ); + + // c[4,0-15] + _mm512_storeu_ps( buf4, c_float_4p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( float ) ) ); + + // c[1,0-15] + memcpy( c + ( rs_c * 1 ) + ( 0*16 ), buf1, ( n0_rem * sizeof( float ) ) ); + + // c[2,0-15] + memcpy( c + ( rs_c * 2 ) + ( 0*16 ), buf2, ( n0_rem * sizeof( float ) ) ); + + // c[3,0-15] + memcpy( c + ( rs_c * 3 ) + ( 0*16 ), buf3, ( n0_rem * sizeof( float ) ) ); + + // c[4,0-15] + memcpy( c + ( rs_c * 4 ) + ( 0*16 ), buf4, ( n0_rem * sizeof( float ) ) ); + +} + +// 4xlt16 bf16 fringe kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_4xLT16_DISABLE, + &&POST_OPS_BIAS_4xLT16, + &&POST_OPS_RELU_4xLT16, + &&POST_OPS_RELU_SCALE_4xLT16, + &&POST_OPS_DOWNSCALE_4xLT16 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // For corner cases. + float buf0[16]; + float buf1[16]; + float buf2[16]; + float buf3[16]; + + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. 
+ a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + } + // Handle k remainder. + + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf1, ( c + ( rs_c * 1 ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf2, ( c + ( rs_c * 2 ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf3, ( c + ( rs_c * 3 ) ), ( n0_rem * sizeof( float ) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_ps( buf0 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( buf1 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( buf2 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( buf3 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4xLT16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_ps( buf0 ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4xLT16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4xLT16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4xLT16: + { + // c[0, 0-15] + CVT_F32_BF16_LT16(c_float_0p0,0,0); + + // c[1, 0-15] + CVT_F32_BF16_LT16(c_float_1p0,1,0); + + // c[2, 0-15] + CVT_F32_BF16_LT16(c_float_2p0,2,0); + + // c[3, 0-15] + 
CVT_F32_BF16_LT16(c_float_3p0,3,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4xLT16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( buf0, c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( buf1, c_float_1p0 ); + + // c[2,0-15] + _mm512_storeu_ps( buf2, c_float_2p0 ); + + // c[3,0-15] + _mm512_storeu_ps( buf3, c_float_3p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( float ) ) ); + + // c[1,0-15] + memcpy( c + ( rs_c * 1 ) + ( 0*16 ), buf1, ( n0_rem * sizeof( float ) ) ); + + // c[2,0-15] + memcpy( c + ( rs_c * 2 ) + ( 0*16 ), buf2, ( n0_rem * sizeof( float ) ) ); + + // c[3,0-15] + memcpy( c + ( rs_c * 3 ) + ( 0*16 ), buf3, ( n0_rem * sizeof( float ) ) ); + +} + +// 3xlt16 bf16 fringe kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3xlt16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_3xLT16_DISABLE, + &&POST_OPS_BIAS_3xLT16, + &&POST_OPS_RELU_3xLT16, + &&POST_OPS_RELU_SCALE_3xLT16, + &&POST_OPS_DOWNSCALE_3xLT16 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // For corner cases. + float buf0[16]; + float buf1[16]; + float buf2[16]; + + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. 
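+ // Only k_partial_pieces (here 1) bf16 element is copied into the
+ // zero-initialized a_kfringe_buf, so the second element of each
+ // broadcast bf16 pair is zero and adds nothing to the dot product.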
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf1, ( c + ( rs_c * 1 ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf2, ( c + ( rs_c * 2 ) ), ( n0_rem * sizeof( float) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_ps( buf0 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( buf1 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( buf2 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3xLT16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_ps( buf0 ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3xLT16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3xLT16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_3xLT16: + { + // c[0, 0-15] + CVT_F32_BF16_LT16(c_float_0p0,0,0); + + // c[1, 0-15] + CVT_F32_BF16_LT16(c_float_1p0,1,0); + + // c[2, 0-15] + 
CVT_F32_BF16_LT16(c_float_2p0,2,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3xLT16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( buf0, c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( buf1, c_float_1p0 ); + + // c[2,0-15] + _mm512_storeu_ps( buf2, c_float_2p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( float ) ) ); + + // c[1,0-15] + memcpy( c + ( rs_c * 1 ) + ( 0*16 ), buf1, ( n0_rem * sizeof( float ) ) ); + + // c[2,0-15] + memcpy( c + ( rs_c * 2 ) + ( 0*16 ), buf2, ( n0_rem * sizeof( float ) ) ); + +} + +// 2xlt16 bf16 fringe kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2xlt16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_2xLT16_DISABLE, + &&POST_OPS_BIAS_2xLT16, + &&POST_OPS_RELU_2xLT16, + &&POST_OPS_RELU_SCALE_2xLT16, + &&POST_OPS_DOWNSCALE_2xLT16 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // For corner cases. + float buf0[16]; + float buf1[16]; + + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf1, ( c + ( rs_c * 1 ) ), ( n0_rem * sizeof( float) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_ps( buf0 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( buf1 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2xLT16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_ps( buf0 ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2xLT16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2xLT16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2xLT16: + { + // c[0, 0-15] + CVT_F32_BF16_LT16(c_float_0p0,0,0); + + // c[1, 0-15] + CVT_F32_BF16_LT16(c_float_1p0,1,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2xLT16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( buf0, c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( buf1, c_float_1p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( float ) ) ); + + // c[1,0-15] + memcpy( c + ( rs_c * 1 ) + ( 0*16 ), buf1, ( n0_rem * sizeof( float ) ) ); + +} + +// 1xlt16 bf16 fringe kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1xlt16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_1xLT16_DISABLE, + &&POST_OPS_BIAS_1xLT16, + &&POST_OPS_RELU_1xLT16, + &&POST_OPS_RELU_SCALE_1xLT16, + &&POST_OPS_DOWNSCALE_1xLT16 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // For corner cases. + float buf0[16]; + + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. 
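+ // The 32-bit load below picks up the adjacent bf16 pair a[0,kr] and
+ // a[0,kr+1]; _mm512_set1_epi32 replicates that pair to all 16 lanes, and
+ // _mm512_dpbf16_ps then accumulates a[0,kr]*b[kr,j] + a[0,kr+1]*b[kr+1,j]
+ // into the f32 accumulator for every column j of the 16-wide tile.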
+ a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( float ) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_ps( buf0 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1xLT16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_ps( buf0 ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1xLT16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1xLT16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1xLT16: + { + // c[0, 0-15] + CVT_F32_BF16_LT16(c_float_0p0,0,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1xLT16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( buf0, c_float_0p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( float ) ) ); + +} + +// 5x16 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_5x16_DISABLE, + &&POST_OPS_BIAS_5x16, + &&POST_OPS_RELU_5x16, + &&POST_OPS_RELU_SCALE_5x16, + &&POST_OPS_DOWNSCALE_5x16 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. 
+ __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. 
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 4 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5x16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + __m512 selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 4 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, 
c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_5x16: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[4, 0-15] + CVT_F32_BF16(c_float_4p0,4,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 0*16 ), c_float_4p0 ); +} + +// 4x16 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_4x16_DISABLE, + &&POST_OPS_BIAS_4x16, + &&POST_OPS_RELU_4x16, + &&POST_OPS_RELU_SCALE_4x16, + &&POST_OPS_DOWNSCALE_4x16 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. 
+ // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + + // Scale C by beta. 
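+ // Unlike the *lt16 kernels, this full 16-column kernel can read C
+ // directly with unmasked 512-bit loads; each row is scaled by beta and
+ // folded into the alpha-scaled accumulators.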
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x16: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x16_DISABLE: + ; + + // Store the results. 
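+ // POST_OPS_4x16_DISABLE is the first entry in post_ops_labels, so when
+ // no post-op (or no further post-op) applies, the label-based dispatch
+ // presumably lands here and simply falls through to the stores below.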
+ // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); +} + +// 3x16 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_3x16_DISABLE, + &&POST_OPS_BIAS_3x16, + &&POST_OPS_RELU_3x16, + &&POST_OPS_RELU_SCALE_3x16, + &&POST_OPS_DOWNSCALE_3x16 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. 
+ // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3x16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_3x16: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x16_DISABLE: + ; + + + // Store the results. 
+ // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); +} + +// 2x16 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_2x16_DISABLE, + &&POST_OPS_BIAS_2x16, + &&POST_OPS_RELU_2x16, + &&POST_OPS_RELU_SCALE_2x16, + &&POST_OPS_DOWNSCALE_2x16 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2x16: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); +} + +// 1x16 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_1x16_DISABLE, + &&POST_OPS_BIAS_1x16, + &&POST_OPS_RELU_1x16, + &&POST_OPS_RELU_SCALE_1x16, + &&POST_OPS_DOWNSCALE_1x16 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. 
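+ // Only k_partial_pieces ( = 1 here, since k_partial_pieces = k0 % 2 )
+ // bf16 element is copied into the zero-initialized 32-bit a_kfringe_buf,
+ // so the unused half of the broadcast pair is zero and contributes
+ // nothing to the dot product.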
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1x16: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); +} + +// 5x32 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_5x32_DISABLE, + &&POST_OPS_BIAS_5x32, + &&POST_OPS_RELU_5x32, + &&POST_OPS_RELU_SCALE_5x32, + &&POST_OPS_DOWNSCALE_5x32 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + __m512 c_float_4p1 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. 
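+ // b0 and b1 hold columns 0-15 and 16-31 of the packed B panel, so each
+ // broadcast A pair is reused for two _mm512_dpbf16_ps updates, one per
+ // 16-column half of the output row.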
+ // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-31] = a[4,kr:kr+2]*b[kr:kr+2,0-31] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + + // Broadcast a[3,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. 
+ // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + + // Broadcast a[4,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-31] = a[4,kr:kr+2]*b[kr:kr+2,0-31] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); + + // c[4,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 4 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 4 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p1 = _mm512_add_ps( selector1, c_float_4p1 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5x32: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( 
float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + __m512 selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 4 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x32: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x32: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[2, 0-15] + 
RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_4p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_5x32: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_F32_BF16(c_float_3p1,3,1); + + // c[4, 0-15] + CVT_F32_BF16(c_float_4p0,4,0); + + // c[4, 16-31] + CVT_F32_BF16(c_float_4p1,4,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x32_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 1*16 ), c_float_3p1 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 0*16 ), c_float_4p0 ); + + // c[4,16-31] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 1*16 ), c_float_4p1 ); +} + +// 4x32 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_4x32_DISABLE, + &&POST_OPS_BIAS_4x32, + &&POST_OPS_RELU_4x32, + &&POST_OPS_RELU_SCALE_4x32, + &&POST_OPS_DOWNSCALE_4x32 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. 
+ // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + + // Broadcast a[3,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x32: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 
16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x32: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x32: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x32: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_F32_BF16(c_float_3p1,3,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x32_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 1*16 ), c_float_3p1 ); +} + +// 3x32 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_3x32_DISABLE, + &&POST_OPS_BIAS_3x32, + &&POST_OPS_RELU_3x32, + &&POST_OPS_RELU_SCALE_3x32, + &&POST_OPS_DOWNSCALE_3x32 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. 
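Every kernel in this hunk dispatches its post-ops through a post_ops_labels table of GCC label addresses plus the POST_OP_LABEL_LASTK_SAFE_JUMP / ..._WITH_NEXT_PTR macros, whose definitions are not part of this hunk. A minimal sketch of that labels-as-values pattern, with hypothetical op codes and labels, assuming the macros jump to the entry matching each post-op and end at the terminating label once the list is exhausted:

#include <stdio.h>

enum { OP_DISABLE = 0, OP_BIAS, OP_RELU };

static void post_op_dispatch_demo( const int* op_codes, int n_ops )
{
	// GNU C labels-as-values, the same extension the kernels rely on.
	static void* labels[] = { &&op_disable, &&op_bias, &&op_relu };
	int i = 0;
	if ( n_ops == 0 ) goto op_disable;
	goto *labels[ op_codes[ i ] ];
op_bias:
	printf( "apply bias\n" );
	if ( ++i < n_ops ) goto *labels[ op_codes[ i ] ]; else goto op_disable;
op_relu:
	printf( "apply relu\n" );
	if ( ++i < n_ops ) goto *labels[ op_codes[ i ] ]; else goto op_disable;
op_disable:
	return;
}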
+ __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. 
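The _mm512_dpbf16_ps accumulation used throughout these loops performs, per 32-bit lane, a two-element bf16 dot product added to the f32 accumulator. A scalar model of one lane (illustrative only; it ignores the instruction's exact intermediate rounding):

#include <stdint.h>
#include <string.h>

// bf16 is the upper half of an IEEE-754 binary32 value.
static inline float bf16_to_f32( uint16_t h )
{
	uint32_t u = ( uint32_t )h << 16;
	float f;
	memcpy( &f, &u, sizeof( f ) );
	return f;
}

// One lane of acc = _mm512_dpbf16_ps( acc, a, b ): acc += a.lo*b.lo + a.hi*b.hi.
static inline float dpbf16_lane( float acc, const uint16_t a[ 2 ], const uint16_t b[ 2 ] )
{
	return acc + bf16_to_f32( a[ 0 ] ) * bf16_to_f32( b[ 0 ] )
	           + bf16_to_f32( a[ 1 ] ) * bf16_to_f32( b[ 1 ] );
}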
+ // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3x32: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + } + + 
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x32: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x32: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_3x32: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x32_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); +} + +// 2x32 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_2x32_DISABLE, + &&POST_OPS_BIAS_2x32, + &&POST_OPS_RELU_2x32, + &&POST_OPS_RELU_SCALE_2x32, + &&POST_OPS_DOWNSCALE_2x32 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. 
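The A broadcasts rely on two adjacent bf16 elements of a row occupying one 32-bit word, so a single _mm512_set1_epi32 replicates the pair a[r,kr], a[r,kr+1] into every lane, matching the pair layout _mm512_dpbf16_ps expects. A small sketch of that packing (bf16_raw_t and pack_bf16_pair are illustrative names, not the library's):

#include <stdint.h>
#include <string.h>

typedef uint16_t bf16_raw_t;   // stand-in for the library's bfloat16 storage type

// Reinterpret two consecutive bf16 values as the int32 that gets broadcast.
static inline int32_t pack_bf16_pair( const bf16_raw_t* p )
{
	int32_t w;
	memcpy( &w, p, sizeof( w ) );  // low 16 bits = p[0], high 16 bits = p[1]
	return w;                      // _mm512_set1_epi32( w ) replicates the pair 16x
}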
+ a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31]
+ c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 );
+ c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 );
+ }
+ // Handle k remainder.
+ if ( k_partial_pieces > 0 )
+ {
+ b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) );
+ b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) );
+
+ // Broadcast a[0,kr:kr+2].
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) );
+ a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31]
+ c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 );
+ c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 );
+
+ // Broadcast a[1,kr:kr+2].
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) );
+ a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31]
+ c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 );
+ c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 );
+ }
+
+ // Load alpha and beta
+ __m512 selector1 = _mm512_set1_ps( alpha );
+ __m512 selector2 = _mm512_set1_ps( beta );
+
+ // Scale by alpha
+ c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 );
+ c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 );
+
+ c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 );
+ c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 );
+
+ // Scale C by beta.
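With the k loop done, each kernel applies the standard GEMM update C = alpha*(A*B) + beta*C, reading C only when beta is non-zero. A scalar reference of that update step (illustrative; ab stands for the already-accumulated A*B tile):

static void scale_update_ref
     (
       float*       c,
       int          rs_c,
       int          m,
       int          n,
       const float* ab,     // m x n accumulated A*B tile, row-major
       float        alpha,
       float        beta
     )
{
	for ( int i = 0; i < m; ++i )
	{
		for ( int j = 0; j < n; ++j )
		{
			float t = alpha * ab[ ( i * n ) + j ];
			// Matches the kernel: C is only loaded when beta != 0.
			c[ ( i * rs_c ) + j ] = ( beta != 0.0f ) ? ( beta * c[ ( i * rs_c ) + j ] + t ) : t;
		}
	}
}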
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x32: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x32: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x32: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2x32: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x32_DISABLE: + ; + + // Store the results. 
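The RELU_SCALE_OP_F32_AVX512 macro used in the RELU_SCALE blocks is defined elsewhere; judging only from its inputs here (a zero vector, a scale factor read from op_args2, and a compare mask), it appears to scale the negative entries, i.e. a leaky-ReLU-style op. A scalar sketch of that reading (an assumption, not the macro's actual body):

static inline float relu_scale_ref( float x, float scale )
{
	// Negative values are scaled; non-negative values pass through unchanged.
	return ( x < 0.0f ) ? ( x * scale ) : x;
}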
+ // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); +} + +// 1x32 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_1x32_DISABLE, + &&POST_OPS_BIAS_1x32, + &&POST_OPS_RELU_1x32, + &&POST_OPS_RELU_SCALE_1x32, + &&POST_OPS_DOWNSCALE_1x32 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x32: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x32: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x32: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1x32: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x32_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); +} + +// 5x48 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_5x48_DISABLE, + &&POST_OPS_BIAS_5x48, + &&POST_OPS_RELU_5x48, + &&POST_OPS_RELU_SCALE_5x48, + &&POST_OPS_DOWNSCALE_5x48 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. 
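The 5x48 kernel that follows keeps the whole 5-row by 48-column C tile in registers: three ZMM accumulators of 16 floats per row, 15 in total, named c_float_<row>p<group>. A tiny helper showing that mapping (acc_index is a hypothetical illustration, not library code):

// Maps (row, col) of the 5x48 tile to a flat accumulator index:
// c_float_<row>p<col/16> covers columns [ 16*(col/16), 16*(col/16)+15 ].
static inline int acc_index( int row, int col )
{
	return ( row * 3 ) + ( col / 16 );   // 5 rows x 3 vectors = 15 __m512 accumulators
}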
+ __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + __m512 c_float_3p2 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + __m512 c_float_4p1 = _mm512_setzero_ps(); + __m512 c_float_4p2 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-47] = a[4,kr:kr+2]*b[kr:kr+2,0-47] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. 
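When k0 is odd, only one bf16 of the final pair exists, so the remainder path memcpy's just the valid k_partial_pieces * sizeof( bfloat16 ) bytes into the zero-initialized a_kfringe_buf before broadcasting; the zero upper half of the pair then contributes nothing to the dot product and no out-of-bounds read of A occurs. A scalar sketch of that fringe load (load_a_kfringe is an illustrative name):

#include <stdint.h>
#include <string.h>

static inline int32_t load_a_kfringe( const uint16_t* a_row_k, int k_partial )
{
	int32_t buf = 0;                                      // missing half of the pair stays zero
	memcpy( &buf, a_row_k, k_partial * sizeof( uint16_t ) );
	return buf;                                           // broadcast via _mm512_set1_epi32
}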
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + + // Broadcast a[3,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); + + // Broadcast a[4,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-47] = a[4,kr:kr+2]*b[kr:kr+2,0-47] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + c_float_3p2 = _mm512_mul_ps( selector1, c_float_3p2 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); + c_float_4p2 = _mm512_mul_ps( selector1, c_float_4p2 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p2 = _mm512_add_ps( selector1, c_float_2p2 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p2 = _mm512_add_ps( selector1, c_float_3p2 ); + + // c[4,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 4 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 4 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p1 = _mm512_add_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 4 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p2 = _mm512_add_ps( selector1, c_float_4p2 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5x48: + { + __m512 selector3; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 
= _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + __m512 selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 4 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector5, c_float_4p2 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x48: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + 
c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_max_ps( selector1, c_float_4p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x48: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_4p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_5x48: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_F32_BF16(c_float_1p2,1,2); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_F32_BF16(c_float_2p2,2,2); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_F32_BF16(c_float_3p1,3,1); + + // c[3, 32-47] + CVT_F32_BF16(c_float_3p2,3,2); + + // c[4, 0-15] + CVT_F32_BF16(c_float_4p0,4,0); + + // c[4, 16-31] + CVT_F32_BF16(c_float_4p1,4,1); + + // c[4, 32-47] + CVT_F32_BF16(c_float_4p2,4,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x48_DISABLE: + ; + + // Store the results. 
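The DOWNSCALE blocks funnel each f32 accumulator through CVT_F32_BF16, whose definition is not part of this hunk; presumably it converts the tile back to bf16 for a reduced-precision output. One way such a conversion can be expressed with AVX512-BF16 intrinsics (a sketch of the idea only, not the macro's actual implementation or destination):

#include <immintrin.h>

// Convert 16 f32 lanes to bf16 (round-to-nearest-even) and store them.
static inline void store_f32_as_bf16( void* dst, __m512 acc )
{
	__m256bh bf = _mm512_cvtneps_pbh( acc );
	_mm256_storeu_si256( ( __m256i* )dst, ( __m256i )bf );
}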
+ // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 2*16 ), c_float_2p2 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 1*16 ), c_float_3p1 ); + + // c[3,32-47] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 2*16 ), c_float_3p2 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 0*16 ), c_float_4p0 ); + + // c[4,16-31] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 1*16 ), c_float_4p1 ); + + // c[4,32-47] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 2*16 ), c_float_4p2 ); +} + +// 4x48 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x48) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_4x48_DISABLE, + &&POST_OPS_BIAS_4x48, + &&POST_OPS_RELU_4x48, + &&POST_OPS_RELU_SCALE_4x48, + &&POST_OPS_DOWNSCALE_4x48 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + __m512 c_float_3p2 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. 
+ a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + + // Broadcast a[3,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. 
+ // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + c_float_3p2 = _mm512_mul_ps( selector1, c_float_3p2 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p2 = _mm512_add_ps( selector1, c_float_2p2 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p2 = _mm512_add_ps( selector1, c_float_3p2 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x48: + { + __m512 selector3; + + if ( ( *( 
char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x48: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( 
selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x48: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x48: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_F32_BF16(c_float_1p2,1,2); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_F32_BF16(c_float_2p2,2,2); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_F32_BF16(c_float_3p1,3,1); + + // c[3, 32-47] + CVT_F32_BF16(c_float_3p2,3,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x48_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 2*16 ), c_float_2p2 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 1*16 ), c_float_3p1 ); + + // c[3,32-47] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 2*16 ), c_float_3p2 ); +} + +// 3x48 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_3x48_DISABLE, + &&POST_OPS_BIAS_3x48, + &&POST_OPS_RELU_3x48, + &&POST_OPS_RELU_SCALE_3x48, + &&POST_OPS_DOWNSCALE_3x48 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. 
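In every BIAS block above, op_args2 holding 'r'/'R' selects a row-vector bias: one value per output column, loaded at offset post_op_c_j and added to all rows; otherwise a per-row scalar is read at post_op_c_i + row and broadcast across that row. A scalar sketch of the two layouts (add_bias_ref and its arguments are illustrative names):

static void add_bias_ref
     (
       float*       tile,    // m x n C tile, row-major
       int          m,
       int          n,
       const float* bias,
       char         layout   // 'r'/'R': one bias per column; otherwise one per row
     )
{
	for ( int i = 0; i < m; ++i )
	{
		for ( int j = 0; j < n; ++j )
		{
			tile[ ( i * n ) + j ] += ( ( layout == 'r' ) || ( layout == 'R' ) )
			                         ? bias[ j ]
			                         : bias[ i ];
		}
	}
}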
+ __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. 
+ // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p2 = _mm512_add_ps( selector1, c_float_2p2 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3x48: + { + __m512 selector3; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + 
c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x48: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x48: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_3x48: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_F32_BF16(c_float_1p2,1,2); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_F32_BF16(c_float_2p2,2,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x48_DISABLE: + ; + + // Store the results. 
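+ // The stores below are unmasked 16-float writes, three per row covering
+ // the full 48-column panel; this MN-fringe kernel is therefore expected
+ // to be invoked only when all 48 output columns are valid (narrower
+ // panels are handled by the lt16/16/32 n-fringe kernels).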
+ // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 2*16 ), c_float_2p2 ); +} + +// 2x48 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x48) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_2x48_DISABLE, + &&POST_OPS_BIAS_2x48, + &&POST_OPS_RELU_2x48, + &&POST_OPS_RELU_SCALE_2x48, + &&POST_OPS_DOWNSCALE_2x48 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. 
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x48: + { + __m512 selector3; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + 
c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x48: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x48: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2x48: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_F32_BF16(c_float_1p2,1,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x48_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); +} + +// 1x48 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x48) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_1x48_DISABLE, + &&POST_OPS_BIAS_1x48, + &&POST_OPS_RELU_1x48, + &&POST_OPS_RELU_SCALE_1x48, + &&POST_OPS_DOWNSCALE_1x48 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. 
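+ // _mm512_dpbf16_ps treats each 32-bit lane of its bf16 operands as a
+ // pair of bfloat16 values: for every lane i it computes
+ //   acc[i] += a.lo[i] * b.lo[i] + a.hi[i] * b.hi[i]
+ // with products and accumulation in fp32. This is why A is broadcast two
+ // k-elements at a time and B is packed with 2 rows per ZMM load.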
+ // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x48: + { + __m512 selector3; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x48: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x48: + { + 
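+ // The RELU_SCALE_OP_F32_AVX512 macro (not shown here) appears to apply a
+ // leaky-ReLU style post-op: relu_cmp_mask selects lanes below selector1
+ // (zero), those lanes are scaled by the factor read from op_args2 into
+ // selector2, and non-negative lanes pass through unchanged.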
selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1x48: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x48_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); +} +#endif diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c new file mode 100644 index 0000000000..1a37ab071a --- /dev/null +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c @@ -0,0 +1,2502 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ +#include +#include + +#include "blis.h" +#include "lpgemm_kernels.h" +#include "lpgemm_f32_kern_macros.h" + +#ifdef BLIS_KERNELS_ZEN4 +// 6xlt16 bf16 fringe kernel +LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6xLT16_DISABLE, + &&POST_OPS_BIAS_6xLT16, + &&POST_OPS_RELU_6xLT16, + &&POST_OPS_RELU_SCALE_6xLT16, + &&POST_OPS_DOWNSCALE_6xLT16 + }; + dim_t MR = 6; + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // For corner cases. + float buf0[16]; + float buf1[16]; + float buf2[16]; + float buf3[16]; + float buf4[16]; + float buf5[16]; + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + + __m512 c_float_5p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // Load 2 rows with 16 extended elements each from B to 1 ZMM + // registers. It is to be noted that the B matrix is packed for use + // in bf16 instructions and each load to ZMM register will have 2 + // elements along k direction and 16 elements across n directions, + // so 2x16 elements to a ZMM register. + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + + // Broadcast a[5,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + } + + // Handle k remainder. 
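+ // When k0 is odd a single bfloat16 element of A remains, i.e. only half
+ // of the 2-element pair consumed by _mm512_dpbf16_ps. a_kfringe_buf was
+ // zero-initialised, so memcpy-ing just the remaining element leaves the
+ // upper half of the broadcast 32-bit pair as +0.0 in bf16; the missing
+ // k-element contributes nothing and A is never read past its end.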
+ if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + + // Broadcast a[5,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + + c_float_5p0 = _mm512_mul_ps( selector1, c_float_5p0 ); + + // Scale C by beta. 
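+ // The overall update is C := beta*C + alpha*(A*B). When beta is zero the
+ // branch below is skipped entirely, so C is never read, matching the
+ // usual BLAS convention that C must not contribute (even if it holds
+ // NaNs or is uninitialised) when beta == 0. Since this kernel covers
+ // n0_rem (< 16) columns, the valid part of each C row is first memcpy-ed
+ // into the 16-float buf* scratch arrays so full unmasked ZMM loads can
+ // be used.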
+ if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * ( ir + 0 ) ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf1, ( c + ( rs_c * ( ir + 1 ) ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf2, ( c + ( rs_c * ( ir + 2 ) ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf3, ( c + ( rs_c * ( ir + 3 ) ) ), ( n0_rem * sizeof( float) ) ); + memcpy( buf4, ( c + ( rs_c * ( ir + 4 ) ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf5, ( c + ( rs_c * ( ir + 5 ) ) ), ( n0_rem * sizeof( float ) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_ps( buf0 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( buf1 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( buf2 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( buf3 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + selector1 = _mm512_loadu_ps( buf4 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[5,0-15] + selector1 = _mm512_loadu_ps( buf5 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6xLT16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_ps( buf0 ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + __m512 selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 4 ) ); + __m512 selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 5 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6xLT16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + 
c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[5,0-15] + c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6xLT16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[5, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_5p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_6xLT16: + { + // c[0, 0-15] + CVT_F32_BF16_LT16(c_float_0p0,0,0); + + // c[1, 0-15] + CVT_F32_BF16_LT16(c_float_1p0,1,0); + + // c[2, 0-15] + CVT_F32_BF16_LT16(c_float_2p0,2,0); + + // c[3, 0-15] + CVT_F32_BF16_LT16(c_float_3p0,3,0); + + // c[4, 0-15] + CVT_F32_BF16_LT16(c_float_4p0,4,0); + + // c[5, 0-15] + CVT_F32_BF16_LT16(c_float_5p0,5,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6xLT16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( buf0, c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( buf1, c_float_1p0 ); + + // c[2,0-15] + _mm512_storeu_ps( buf2, c_float_2p0 ); + + // c[3,0-15] + _mm512_storeu_ps( buf3, c_float_3p0 ); + + // c[4,0-15] + _mm512_storeu_ps( buf4, c_float_4p0 ); + + // c[5,0-15] + _mm512_storeu_ps( buf5, c_float_5p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), buf0, ( n0_rem * sizeof( float ) ) ); + + // c[1,0-15] + memcpy( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), buf1, ( n0_rem * sizeof( float ) ) ); + + // c[2,0-15] + memcpy( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), buf2, ( n0_rem * sizeof( float ) ) ); + + // c[3,0-15] + memcpy( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), buf3, ( n0_rem * sizeof( float ) ) ); + + // c[4,0-15] + memcpy( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), buf4, ( n0_rem * sizeof( float ) ) ); + + // c[5,0-15] + memcpy( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), buf5, ( n0_rem * sizeof( float ) ) ); + + a = a + ( MR * ps_a ); + post_op_c_i += MR; + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_bf16bf16f32of32_5xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 4 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_bf16bf16f32of32_4xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 3 ) + { + int cs_a_use = ( cs_a == 2 ) ? 
2 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_bf16bf16f32of32_3xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 2 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_bf16bf16f32of32_2xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 1 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_bf16bf16f32of32_1xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + } +} + +// 6x16 bf16 fringe kernel +LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6x16_DISABLE, + &&POST_OPS_BIAS_6x16, + &&POST_OPS_RELU_6x16, + &&POST_OPS_RELU_SCALE_6x16, + &&POST_OPS_DOWNSCALE_6x16 + }; + dim_t MR = 6; + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + + __m512 c_float_5p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // Load 2 rows with 16 elements each from B to 1 ZMM registers. It + // is to be noted that the B matrix is packed for use in bf16 + // instructions and each load to ZMM register will have 2 elements + // along k direction and 16 elements across n directions, so 2x16 + // elements to a ZMM register. + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. 
+ // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + + // Broadcast a[5,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + } + // Handle k remainder. + + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + + // Broadcast a[5,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. 
+ // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + + c_float_5p0 = _mm512_mul_ps( selector1, c_float_5p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[5,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + __m512 selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 4 ) ); + __m512 selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 5 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + 
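+ // (Per-row bias path: when op_args2 is not 'r'/'R' the bias vector runs
+ // along M, so one scalar per C row is broadcast via selector1..selector6;
+ // the 'r' branch above instead loads 16 bias values along N.)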
+ // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[5,0-15] + c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[5, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_5p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_6x16: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[4, 0-15] + CVT_F32_BF16(c_float_4p0,4,0); + + // c[5, 0-15] + CVT_F32_BF16(c_float_5p0,5,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_float_1p0 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_float_2p0 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_float_3p0 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_float_4p0 ); + + // c[5,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_float_5p0 ); + + a = a + ( MR * ps_a ); + post_op_c_i += MR; + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_bf16bf16f32of32_5x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 4 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_bf16bf16f32of32_4x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 3 ) + { + int cs_a_use = ( cs_a == 2) ? 2 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_bf16bf16f32of32_3x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 2 ) + { + int cs_a_use = ( cs_a == 2 ) ? 
2 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_bf16bf16f32of32_2x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 1 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_bf16bf16f32of32_1x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + } +} + +// 6x32 bf16 fringe kernel +LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6x32_DISABLE, + &&POST_OPS_BIAS_6x32, + &&POST_OPS_RELU_6x32, + &&POST_OPS_RELU_SCALE_6x32, + &&POST_OPS_DOWNSCALE_6x32 + }; + dim_t MR = 6; + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + __m512 c_float_4p1 = _mm512_setzero_ps(); + + __m512 c_float_5p0 = _mm512_setzero_ps(); + __m512 c_float_5p1 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // Load 2 rows with 32 elements each from B to 2 ZMM registers. It + // is to be noted that the B matrix is packed for use in bf16 + // instructions and each load to ZMM register will have 2 elements + // along k direction and 32 elements across n directions, so 2x16 + // elements to a ZMM register. + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + + // Broadcast a[3,kr:kr+2]. 
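+ // a[3,kr] and a[3,kr+1] are two adjacent bfloat16 values occupying 4
+ // bytes, so they are reinterpreted as one int32 and broadcast to all 16
+ // dword lanes of the ZMM register; each lane then pairs with the
+ // corresponding 2x1 slice of the packed B registers b0/b1 in the dpbf16
+ // instructions below.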
+ a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-31] = a[4,kr:kr+2]*b[kr:kr+2,0-31] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + + // Broadcast a[5,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-31] = a[5,kr:kr+2]*b[kr:kr+2,0-31] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + + // Broadcast a[3,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + + // Broadcast a[4,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-31] = a[4,kr:kr+2]*b[kr:kr+2,0-31] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + + // Broadcast a[5,kr:kr+2]. 
+ memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-31] = a[5,kr:kr+2]*b[kr:kr+2,0-31] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_0, b1 ); + } + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); + + c_float_5p0 = _mm512_mul_ps( selector1, c_float_5p0 ); + c_float_5p1 = _mm512_mul_ps( selector1, c_float_5p1 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); + + // c[4,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p1 = _mm512_add_ps( selector1, c_float_4p1 ); + + // c[5,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + + // c[5,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) 
+ ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p1 = _mm512_add_ps( selector1, c_float_5p1 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x32: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( selector2, c_float_5p1 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + __m512 selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 4 ) ); + __m512 selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 5 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); + + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( selector6, c_float_5p1 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x32: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( 
selector1, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); + + // c[5,0-15] + c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); + + // c[5,16-31] + c_float_5p1 = _mm512_max_ps( selector1, c_float_5p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x32: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_4p1) + + // c[5, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_5p0) + + // c[5, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_5p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_6x32: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_F32_BF16(c_float_3p1,3,1); + + // c[4, 0-15] + CVT_F32_BF16(c_float_4p0,4,0); + + // c[4, 16-31] + CVT_F32_BF16(c_float_4p1,4,1); + + // c[5, 0-15] + CVT_F32_BF16(c_float_5p0,5,0); + + // c[5, 16-31] + CVT_F32_BF16(c_float_5p1,5,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x32_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ), c_float_0p1 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ), c_float_1p1 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ), c_float_2p1 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ), c_float_3p1 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_float_4p0 ); + + // c[4,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ), c_float_4p1 ); + + // c[5,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_float_5p0 ); + + // c[5,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_float_5p1 ); + + a = a + ( MR * ps_a ); + post_op_c_i += MR; + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + int cs_a_use = ( cs_a == 2 ) ? 
2 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_bf16bf16f32of32_5x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 4 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_bf16bf16f32of32_4x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 3 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_bf16bf16f32of32_3x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 2 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_bf16bf16f32of32_2x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 1 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_bf16bf16f32of32_1x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + } +} + +// 6x48 bf16 fringe kernel +LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6x48_DISABLE, + &&POST_OPS_BIAS_6x48, + &&POST_OPS_RELU_6x48, + &&POST_OPS_RELU_SCALE_6x48, + &&POST_OPS_DOWNSCALE_6x48 + }; + dim_t MR = 6; + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + __m512 c_float_3p2 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + __m512 c_float_4p1 = _mm512_setzero_ps(); + __m512 c_float_4p2 = _mm512_setzero_ps(); + + __m512 c_float_5p0 = _mm512_setzero_ps(); + __m512 c_float_5p1 = _mm512_setzero_ps(); + __m512 c_float_5p2 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + + // Load 2 rows with 48 elements each from B to 3 ZMM registers. It + // is to be noted that the B matrix is packed for use in bf16 + // instructions and each load to ZMM register will have 2 elements + // along k direction and 16 elements across n directions, so 2x16 + // elements to a ZMM register. 
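+			// Within the packed panel, stepping by rs_b moves to the next k pair,
+			// and each cs_b step selects the following 2x16 block of columns, so
+			// the three loads below together cover columns 0-47 of this k pair.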
+ b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-47] = a[4,kr:kr+2]*b[kr:kr+2,0-47] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + + // Broadcast a[5,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-47] = a[5,kr:kr+2]*b[kr:kr+2,0-47] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_0, b1 ); + c_float_5p2 = _mm512_dpbf16_ps( c_float_5p2, a_bf16_0, b2 ); + + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. 
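+			// Only the valid k remainder (a single bf16 value) is copied into the
+			// zero-initialized 32-bit buffer, so the broadcast still supplies a
+			// full k pair without reading past the end of A.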
+ memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + + // Broadcast a[3,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); + + // Broadcast a[4,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-47] = a[4,kr:kr+2]*b[kr:kr+2,0-47] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + + // Broadcast a[5,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. 
+ // c[5,0-47] = a[5,kr:kr+2]*b[kr:kr+2,0-47] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_0, b1 ); + c_float_5p2 = _mm512_dpbf16_ps( c_float_5p2, a_bf16_0, b2 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + c_float_3p2 = _mm512_mul_ps( selector1, c_float_3p2 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); + c_float_4p2 = _mm512_mul_ps( selector1, c_float_4p2 ); + + c_float_5p0 = _mm512_mul_ps( selector1, c_float_5p0 ); + c_float_5p1 = _mm512_mul_ps( selector1, c_float_5p1 ); + c_float_5p2 = _mm512_mul_ps( selector1, c_float_5p2 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p2 = _mm512_add_ps( selector1, c_float_2p2 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ) ); + selector1 = 
_mm512_mul_ps( selector2, selector1 ); + c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p2 = _mm512_add_ps( selector1, c_float_3p2 ); + + // c[4,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p1 = _mm512_add_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p2 = _mm512_add_ps( selector1, c_float_4p2 ); + + // c[5,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + + // c[5,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p1 = _mm512_add_ps( selector1, c_float_5p1 ); + + // c[5,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p2 = _mm512_add_ps( selector1, c_float_5p2 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x48: + { + __m512 selector3; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( selector2, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_add_ps( selector3, c_float_5p2 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + 
_mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + __m512 selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 4 ) ); + __m512 selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 5 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector5, c_float_4p2 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); + + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( selector6, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_add_ps( selector6, c_float_5p2 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x48: + { + //printf("relu\n"); + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_max_ps( selector1, c_float_4p2 ); + + // c[5,0-15] + c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); + + // c[5,16-31] + c_float_5p1 = _mm512_max_ps( selector1, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_max_ps( selector1, c_float_5p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x48: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 
relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_4p2) + + // c[5, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_5p0) + + // c[5, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_5p1) + + // c[5, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_5p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_6x48: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_F32_BF16(c_float_1p2,1,2); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_F32_BF16(c_float_2p2,2,2); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_F32_BF16(c_float_3p1,3,1); + + // c[3, 32-47] + CVT_F32_BF16(c_float_3p2,3,2); + + // c[4, 0-15] + CVT_F32_BF16(c_float_4p0,4,0); + + // c[4, 16-31] + CVT_F32_BF16(c_float_4p1,4,1); + + // c[4, 32-47] + CVT_F32_BF16(c_float_4p2,4,2); + + // c[5, 0-15] + CVT_F32_BF16(c_float_5p0,5,0); + + // c[5, 16-31] + CVT_F32_BF16(c_float_5p1,5,1); + + // c[5, 32-47] + CVT_F32_BF16(c_float_5p2,5,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x48_DISABLE: + ; + + // Store the results. 
+ // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ), c_float_0p2 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ), c_float_1p2 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ), c_float_2p2 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ), c_float_3p1 ); + + // c[3,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ), c_float_3p2 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_float_4p0 ); + + // c[4,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ), c_float_4p1 ); + + // c[4,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ), c_float_4p2 ); + + // c[5,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_float_5p0 ); + + // c[5,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_float_5p1 ); + + // c[5,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ), c_float_5p2 ); + + a = a + ( MR * ps_a ); + post_op_c_i += MR; + + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_bf16bf16f32of32_5x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 4 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_bf16bf16f32of32_4x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 3 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_bf16bf16f32of32_3x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 2 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_bf16bf16f32of32_2x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 1 ) + { + int cs_a_use = ( cs_a == 2 ) ? 
2 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_bf16bf16f32of32_1x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + } +} +#endif diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16.h b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16.h new file mode 100644 index 0000000000..07b22a5b25 --- /dev/null +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16.h @@ -0,0 +1,67 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_GEMM_BF16_PACKB +#define BLIS_GEMM_BF16_PACKB + +#include "lpgemm_kernels.h" + +BLIS_INLINE dim_t get_packb_bf16bf16f32of32_min_NR() +{ + // This is the minimum NR' required for use in bf16bf16f32 kernels. The idea + // here is that since k needs to be a multiple of 2 (BF16 instr), NR'=16 + // results in total of 2 * NR' = 64 bytes to be loaded, which fits in 1 ZMM + // register. Thus the smallest n fringe kernel dimension has n=16, and thus + // any rounding for buffer sizes should be to 16. + return 16; +} + +void get_packb_nr64_bf16bf16f32of32_strides + ( + dim_t* rs_b, + dim_t* cs_b + ); + +void packb_nr64_bf16bf16f32of32 + ( + bfloat16* pack_b_buffer_bf16bf16f32of32, + const bfloat16* b, + const dim_t ldb, + const dim_t NC, + const dim_t KC, + dim_t* rs_b, + dim_t* cs_b + ); + +#endif //BLIS_GEMM_BF16_PACKB diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16_amd512vnni.c new file mode 100644 index 0000000000..374ac3280e --- /dev/null +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16_amd512vnni.c @@ -0,0 +1,506 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions are
+   met:
+   - Redistributions of source code must retain the above copyright
+     notice, this list of conditions and the following disclaimer.
+   - Redistributions in binary form must reproduce the above copyright
+     notice, this list of conditions and the following disclaimer in the
+     documentation and/or other materials provided with the distribution.
+   - Neither the name(s) of the copyright holder(s) nor the names of its
+     contributors may be used to endorse or promote products derived
+     from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include <immintrin.h>
+#include <string.h>
+
+#include "blis.h"
+#include "lpgemm_config.h"
+#include "aocl_bf16_type.h"
+
+void get_packb_nr64_bf16bf16f32of32_strides
+     (
+       dim_t* rs_b,
+       dim_t* cs_b
+     )
+{
+	*rs_b = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ) * 2;
+	*cs_b = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ) / 2;
+}
+
+#ifdef BLIS_KERNELS_ZEN4
+void packb_nrlt16_bf16bf16f32of32
+     (
+       bfloat16* pack_b_buffer_bf16bf16f32of32,
+       const bfloat16* b,
+       const dim_t ldb,
+       const dim_t KC,
+       const dim_t n0_partial_rem
+     );
+
+void packb_nr16_bf16bf16f32of32
+     (
+       bfloat16* pack_b_buffer_bf16bf16f32of32,
+       const bfloat16* b,
+       const dim_t ldb,
+       const dim_t KC
+     );
+
+void packb_nr32_bf16bf16f32of32
+     (
+       bfloat16* pack_b_buffer_bf16bf16f32of32,
+       const bfloat16* b,
+       const dim_t ldb,
+       const dim_t KC
+     );
+
+void packb_nr48_bf16bf16f32of32
+     (
+       bfloat16* pack_b_buffer_bf16bf16f32of32,
+       const bfloat16* b,
+       const dim_t ldb,
+       const dim_t KC
+     );
+
+void packb_nr64_bf16bf16f32of32
+     (
+       bfloat16* pack_b_buffer_bf16bf16f32of32,
+       const bfloat16* b,
+       const dim_t ldb,
+       const dim_t NC,
+       const dim_t KC,
+       dim_t* rs_b,
+       dim_t* cs_b
+     )
+{
+	dim_t NR = 64;
+
+	// Used for permuting the mm512i elements for use in dpbf16_ps instruction.
+	__m512i selector1 = _mm512_setr_epi64(0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB);
+	__m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF );
+
+	__m512i a0;
+	__m512i b0;
+	__m512i c0;
+	__m512i d0;
+	__m512i a01;
+	__m512i c01;
+
+	dim_t n_full_pieces = NC / NR;
+	dim_t n_full_pieces_loop_limit = n_full_pieces * NR;
+	dim_t n_partial_pieces = NC % NR;
+
+	dim_t k_full_pieces_blks = KC / 2;
+	dim_t k_full_pieces = k_full_pieces_blks * 2;
+	dim_t k_partial_pieces = KC % 2;
+
+	// KC when not multiple of 2 will have padding to make it multiple of 2 in packed buffer.
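+	// e.g. KC = 7 is packed as KC_updated = 8; the extra k row in the last
+	// k pair is filled from zeroed registers in the k remainder path below.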
+ dim_t KC_updated = KC; + if ( k_partial_pieces > 0 ) + { + KC_updated += ( 2 - k_partial_pieces ); + } + + for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR ) + { + for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + // Rearrange for dpbf16_ps, read 2 rows from B with 64 elements in each row. + a0 = _mm512_loadu_epi16( b + ( ldb * ( kr + 0 ) ) + jc ); + b0 = _mm512_loadu_epi16( b + ( ldb * ( kr + 0 ) ) + jc + 32 ); + c0 = _mm512_loadu_epi16( b + ( ldb * ( kr + 1 ) ) + jc ); + d0 = _mm512_loadu_epi16( b + ( ldb * ( kr + 1 ) ) + jc + 32 ); + + a01 = _mm512_unpacklo_epi16( a0, c0 ); + a0 = _mm512_unpackhi_epi16( a0, c0 ); + + c01 = _mm512_unpacklo_epi16( b0, d0 ); + c0 = _mm512_unpackhi_epi16( b0, d0 ); + + b0 = _mm512_permutex2var_epi64( a01, selector1, a0 ); + d0 = _mm512_permutex2var_epi64( c01, selector1, c0 ); + a0 = _mm512_permutex2var_epi64( a01, selector1_1, a0 ); + c0 = _mm512_permutex2var_epi64( c01, selector1_1, c0 ); + + //store to pack_b buffer + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( kr + 0 ) * NR ), b0 ); + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( kr + 0 ) * NR ) + 32, a0 ); + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( kr + 1 ) * NR ), d0 ); + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( kr + 1 ) * NR ) + 32, c0 ); + } + // Handle k remainder. + if( k_partial_pieces > 0) + { + a0 = _mm512_loadu_epi16( b + ( ldb * ( k_full_pieces + 0 ) ) + jc ); + b0 = _mm512_loadu_epi16( b + ( ldb * ( k_full_pieces + 0 ) ) + jc + 32 ); + c0 = _mm512_setzero_si512(); + d0 = _mm512_setzero_si512(); + + a01 = _mm512_unpacklo_epi16( a0, c0 ); + a0 = _mm512_unpackhi_epi16( a0, c0 ); + + c01 = _mm512_unpacklo_epi16( b0, d0 ); + c0 = _mm512_unpackhi_epi16( b0, d0 ); + + b0 = _mm512_permutex2var_epi64( a01, selector1, a0 ); + d0 = _mm512_permutex2var_epi64( c01, selector1, c0 ); + a0 = _mm512_permutex2var_epi64( a01, selector1_1, a0 ); + c0 = _mm512_permutex2var_epi64( c01, selector1_1, c0 ); + + //store to pack_b buffer + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( k_full_pieces + 0 ) * NR ), b0 ); + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( k_full_pieces + 0 ) * NR ) + 32, a0 ); + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( k_full_pieces + 1 ) * NR ), d0 ); + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( k_full_pieces + 1 ) * NR ) + 32, c0 ); + } + } + + if(n_partial_pieces > 0) + { + dim_t n0_partial_rem = n_partial_pieces % 16; + dim_t n0_partial_pack = 0; + + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization after packing. Any n0 < NR(64) can be expressed + // as n0 = 48 + n` / n0 = 32 + n` / n0 = 16 + n`, where n` < 16. 
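+		// e.g. n_partial_pieces = 59 packs a 48-wide panel via packb_nr48 and
+		// the remaining n0_partial_rem = 59 % 16 = 11 columns via packb_nrlt16.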
+ dim_t n0_48 = n_partial_pieces / 48; + dim_t n0_32 = n_partial_pieces / 32; + dim_t n0_16 = n_partial_pieces / 16; + + if ( n0_48 == 1 ) + { + packb_nr48_bf16bf16f32of32 + ( + ( pack_b_buffer_bf16bf16f32of32 + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + n_full_pieces_loop_limit ), ldb, KC + ); + + n0_partial_pack = 48; + } + else if ( n0_32 == 1 ) + { + packb_nr32_bf16bf16f32of32 + ( + ( pack_b_buffer_bf16bf16f32of32 + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + n_full_pieces_loop_limit ), ldb, KC + ); + + n0_partial_pack = 32; + } + else if ( n0_16 == 1 ) + { + packb_nr16_bf16bf16f32of32 + ( + ( pack_b_buffer_bf16bf16f32of32 + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + n_full_pieces_loop_limit ), ldb, KC + ); + + n0_partial_pack = 16; + } + + if ( n0_partial_rem > 0 ) + { + packb_nrlt16_bf16bf16f32of32 + ( + ( pack_b_buffer_bf16bf16f32of32 + ( n_full_pieces_loop_limit * KC_updated ) + + ( n0_partial_pack * KC_updated ) ), + ( b + n_full_pieces_loop_limit + n0_partial_pack ), ldb, KC, + n0_partial_rem + ); + } + } + *rs_b = NR * 2; + *cs_b = NR / 2; +} + +void packb_nr48_bf16bf16f32of32 +( + bfloat16* pack_b_buffer_bf16bf16f32of32, + const bfloat16* b, + const dim_t ldb, + const dim_t KC + ) +{ + dim_t NR1 = 32; + dim_t NR2 = 16; + + // Used for permuting the mm512i elements for use in dpbf16_ps instruction. + __m512i selector1 = _mm512_setr_epi64(0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + + __m512i a0x; + __m512i b0x; + __m512i c0x; + __m512i a01x; + + __m256i a0; + __m256i b0; + __m256i c0; + __m256i a01; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + dim_t kr_new = 0; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + // Rearrange for dpbf16_ps, read 2 rows from B with 32 elements in each row. + a0x = _mm512_loadu_epi16( b + ( ldb * ( kr + 0 ) ) ); + c0x = _mm512_loadu_epi16( b + ( ldb * ( kr + 1 ) ) ); + + a01x = _mm512_unpacklo_epi16( a0x, c0x ); + a0x = _mm512_unpackhi_epi16( a0x, c0x ); + + b0x = _mm512_permutex2var_epi64( a01x, selector1, a0x ); + a0x = _mm512_permutex2var_epi64( a01x, selector1_1, a0x ); + + //First 2x32 elements + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 0 ) * NR1 ), b0x ); + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR1 ), a0x ); + + // Rearrange for dpbf16_ps, read 2 rows from B with next 16 elements in each row. + a0 = _mm256_loadu_epi16( b + ( ldb * ( kr + 0 ) ) + NR1 ); + c0 = _mm256_loadu_epi16( b + ( ldb * ( kr + 1 ) ) + NR1 ); + + a01 = _mm256_unpacklo_epi16( a0, c0 ); + a0 = _mm256_unpackhi_epi16( a0, c0 ); + + b0 = _mm256_permute2f128_si256(a01, a0, 0x20); + a0 = _mm256_permute2f128_si256(a01, a0, 0x31); + + //Last 2x16 elements + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 2 ) * NR1 ), b0 ); + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 2 ) * NR1 ) + NR2, a0 ); + + kr_new += 3; + } + // Handle k remainder. 
+ if ( k_partial_pieces > 0 ) + { + a0x = _mm512_loadu_epi16( b + ( ldb * ( k_full_pieces + 0 ) ) ); + c0x = _mm512_setzero_si512(); + + a01x = _mm512_unpacklo_epi16( a0x, c0x ); + a0x = _mm512_unpackhi_epi16( a0x, c0x ); + + b0x = _mm512_permutex2var_epi64( a01x, selector1, a0x ); + a0x = _mm512_permutex2var_epi64( a01x, selector1_1, a0x ); + + //First 2x32 elements + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 0 ) * NR1 ), b0x ); + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR1 ), a0x ); + + a0 = _mm256_loadu_epi16( b + ( ldb * ( k_full_pieces + 0 ) ) + NR1 ); + c0 = _mm256_setzero_si256(); + + a01 = _mm256_unpacklo_epi16( a0, c0 ); + a0 = _mm256_unpackhi_epi16( a0, c0 ); + + b0 = _mm256_permute2f128_si256(a01, a0, 0x20); + a0 = _mm256_permute2f128_si256(a01, a0, 0x31); + + //Last 2x16 elements + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 2 ) * NR1 ), b0 ); + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 2 ) * NR1 ) + NR2, a0 ); + } +} + +void packb_nr32_bf16bf16f32of32 +( + bfloat16* pack_b_buffer_bf16bf16f32of32, + const bfloat16* b, + const dim_t ldb, + const dim_t KC + ) +{ + dim_t NR = 32; + + // Used for permuting the mm512i elements for use in dpbf16_ps instruction. + __m512i selector1 = _mm512_setr_epi64(0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + + __m512i a0; + __m512i b0; + __m512i c0; + __m512i a01; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + dim_t kr_new = 0; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + // Rearrange for dpbf16_ps, read 2 rows from B with 32 elements in each row. + a0 = _mm512_loadu_epi16( b + ( ldb * ( kr + 0 ) ) ); + c0 = _mm512_loadu_epi16( b + ( ldb * ( kr + 1 ) ) ); + + a01 = _mm512_unpacklo_epi16( a0, c0 ); + a0 = _mm512_unpackhi_epi16( a0, c0 ); + + b0 = _mm512_permutex2var_epi64( a01, selector1, a0 ); + a0 = _mm512_permutex2var_epi64( a01, selector1_1, a0 ); + + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new ) * NR ), b0 ); + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR ), a0 ); + + kr_new += 2; + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + a0 = _mm512_loadu_epi16( b + ( ldb * ( k_full_pieces + 0 ) ) ); + c0 = _mm512_setzero_si512(); + + a01 = _mm512_unpacklo_epi16( a0, c0 ); + a0 = _mm512_unpackhi_epi16( a0, c0 ); + + b0 = _mm512_permutex2var_epi64( a01, selector1, a0 ); + a0 = _mm512_permutex2var_epi64( a01, selector1_1, a0 ); + + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new ) * NR ), b0 ); + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR ), a0 ); + } +} + +void packb_nr16_bf16bf16f32of32 +( + bfloat16* pack_b_buffer_bf16bf16f32of32, + const bfloat16* b, + const dim_t ldb, + const dim_t KC + ) +{ + dim_t NR = 16; + + __m256i a0; + __m256i b0; + __m256i c0; + __m256i a01; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + dim_t kr_new = 0; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + // Rearrange for dpbf16_ps, read 2 rows from B with 16 elements in each row. 
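+		// The 256-bit unpacklo/unpackhi plus permute2f128 sequence below
+		// interleaves the two k rows into (kr, kr+1) element pairs, mirroring
+		// the 512-bit permutex2var path used for the wider panels.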
+ a0 = _mm256_loadu_epi16( b + ( ldb * ( kr + 0 ) ) ); + c0 = _mm256_loadu_epi16( b + ( ldb * ( kr + 1 ) ) ); + + a01 = _mm256_unpacklo_epi16( a0, c0 ); + a0 = _mm256_unpackhi_epi16( a0, c0 ); + + b0 = _mm256_permute2f128_si256(a01, a0, 0x20); + a0 = _mm256_permute2f128_si256(a01, a0, 0x31); + + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 0 ) * NR ), b0 ); + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR ), a0 ); + + kr_new += 2; + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + a0 = _mm256_loadu_epi16( b + ( ldb * ( k_full_pieces + 0 ) ) ); + c0 = _mm256_setzero_si256(); + + a01 = _mm256_unpacklo_epi16( a0, c0 ); + a0 = _mm256_unpackhi_epi16( a0, c0 ); + + b0 = _mm256_permute2f128_si256(a01, a0, 0x20); + a0 = _mm256_permute2f128_si256(a01, a0, 0x31); + + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 0 ) * NR ), b0 ); + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR ), a0 ); + } +} + +void packb_nrlt16_bf16bf16f32of32 +( + bfloat16* pack_b_buffer_bf16bf16f32of32, + const bfloat16* b, + const dim_t ldb, + const dim_t KC, + const dim_t n0_partial_rem + ) +{ + dim_t NR = 16; + + __m256i a0; + __m256i b0; + __m256i c0; + __m256i a01; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + dim_t kr_new = 0; + + bfloat16 buf0[16]; + bfloat16 buf1[16]; + + for ( int kr = 0; kr < k_full_pieces; kr += 2 ) + { + memcpy( buf0, ( b + ( ldb * ( kr + 0 ) ) ), ( n0_partial_rem * sizeof( bfloat16 ) ) ); + memcpy( buf1, ( b + ( ldb * ( kr + 1 ) ) ), ( n0_partial_rem * sizeof( bfloat16 ) ) ); + // Rearrange for dpbf16_ps, read 2 rows from B with next 16 elements in each row. + a0 = _mm256_loadu_epi16( buf0 ); + c0 = _mm256_loadu_epi16( buf1 ); + + a01 = _mm256_unpacklo_epi16( a0, c0 ); + a0 = _mm256_unpackhi_epi16( a0, c0 ); + + b0 = _mm256_permute2f128_si256(a01, a0, 0x20); + a0 = _mm256_permute2f128_si256(a01, a0, 0x31); + + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 0 ) * NR ), b0 ); + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR ), a0 ); + + kr_new += 2; + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( bfloat16 ) ) ); + a0 = _mm256_loadu_epi16( buf0 ); + c0 = _mm256_setzero_si256(); + + a01 = _mm256_unpacklo_epi16( a0, c0 ); + a0 = _mm256_unpackhi_epi16( a0, c0 ); + + b0 = _mm256_permute2f128_si256(a01, a0, 0x20); + a0 = _mm256_permute2f128_si256(a01, a0, 0x31); + + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 0 ) * NR ), b0 ); + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR ), a0 ); + } +} +#endif diff --git a/addon/aocl_gemm/kernels/lpgemm_kernels.h b/addon/aocl_gemm/kernels/lpgemm_kernels.h new file mode 100644 index 0000000000..7b73ba27e9 --- /dev/null +++ b/addon/aocl_gemm/kernels/lpgemm_kernels.h @@ -0,0 +1,265 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_LPGEMM_KERN_H +#define BLIS_LPGEMM_KERN_H + +#include "lpgemm_post_ops.h" +#include "aocl_bf16_type.h" + +#define LPGEMM_MAIN_KERN(A_type,B_type,C_type,LP_SFX) \ +void lpgemm_rowvar_ ## LP_SFX \ + ( \ + const dim_t m0, \ + const dim_t n0, \ + const dim_t k0, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const dim_t ps_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + C_type* c, \ + const dim_t rs_c, \ + const dim_t cs_c, \ + const C_type alpha, \ + const C_type beta, \ + bool is_last_k, \ + dim_t post_op_c_i, \ + dim_t post_op_c_j, \ + lpgemm_post_op* post_ops_list, \ + const dim_t rs_c_downscale \ + ) \ + +LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64); +LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32); +LPGEMM_MAIN_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x64); + +#define LPGEMM_M_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ +void lpgemm_rowvar_ ## LP_SFX \ + ( \ + const dim_t k0, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + C_type* c, \ + const dim_t rs_c, \ + const C_type alpha, \ + const C_type beta, \ + bool is_last_k, \ + dim_t post_op_c_i, \ + dim_t post_op_c_j, \ + lpgemm_post_op* post_ops_list, \ + const dim_t rs_c_downscale \ + ) \ + +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x64); +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x64); +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x64); +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x64); +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x64); + +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32); +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32); +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32); + +LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x64); +LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x64); +LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x64); +LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x64); +LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x64); + +#define LPGEMM_N_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ +void lpgemm_rowvar_ ## LP_SFX \ + ( \ + const dim_t m0, \ + const dim_t k0, \ + const A_type* a, \ + 
const dim_t rs_a, \ + const dim_t cs_a, \ + const dim_t ps_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + C_type* c, \ + const dim_t rs_c, \ + const C_type alpha, \ + const C_type beta, \ + bool is_last_k, \ + dim_t post_op_c_i, \ + dim_t post_op_c_j, \ + lpgemm_post_op* post_ops_list, \ + const dim_t rs_c_downscale \ + ) \ + +LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16); +LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32); +LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48); + +LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16); + +LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x16); +LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x32); +LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x48); + +#define LPGEMM_N_LT_NR0_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ +void lpgemm_rowvar_ ## LP_SFX \ + ( \ + const dim_t m0, \ + const dim_t k0, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const dim_t ps_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + C_type* c, \ + const dim_t rs_c, \ + const C_type alpha, \ + const C_type beta, \ + const dim_t n0_rem, \ + bool is_last_k, \ + dim_t post_op_c_i, \ + dim_t post_op_c_j, \ + lpgemm_post_op* post_ops_list, \ + const dim_t rs_c_downscale \ + ) \ + +LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16); + +LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16); + +LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6xlt16); + +#define LPGEMM_MN_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ +void lpgemm_rowvar_ ## LP_SFX \ + ( \ + const dim_t k0, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + C_type* c, \ + const dim_t rs_c, \ + const C_type alpha, \ + const C_type beta, \ + bool is_last_k, \ + dim_t post_op_c_i, \ + dim_t post_op_c_j, \ + lpgemm_post_op* post_ops_list, \ + const dim_t rs_c_downscale \ + ) \ + +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x16); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x16); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x16); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x16); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x16); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x32); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x32); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x32); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x32); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x32); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x48); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x48); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x48); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x48); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x48); + +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16); + +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x16); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x16); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x16); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x16); 
+LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x16); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x32); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x32); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x32); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x32); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x32); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x48); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x48); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x48); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x48); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x48); + +#define LPGEMM_MN_LT_NR0_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ +void lpgemm_rowvar_ ## LP_SFX \ + ( \ + const dim_t k0, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + C_type* c, \ + const dim_t rs_c, \ + const C_type alpha, \ + const C_type beta, \ + const dim_t n0_rem, \ + bool is_last_k, \ + dim_t post_op_c_i, \ + dim_t post_op_c_j, \ + lpgemm_post_op* post_ops_list, \ + const dim_t rs_c_downscale \ + ) \ + +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1xlt16); + +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16); + +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1xlt16); + +#endif //BLIS_LPGEMM_KERN_H diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c new file mode 100644 index 0000000000..f7ad5f2d23 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c @@ -0,0 +1,703 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include +#include "blis.h" +#include "lpgemm_kernels.h" +#include "lpgemm_s16_kern_macros.h" + +// 6x32 int8o16 kernel +LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) +{ + static void *post_ops_labels[] = + { + &&POST_OPS_6x32_DISABLE, + &&POST_OPS_BIAS_6x32, + &&POST_OPS_RELU_6x32, + &&POST_OPS_RELU_SCALE_6x32, + &&POST_OPS_DOWNSCALE_6x32 + }; + + dim_t MR = 6; + dim_t NR = 32; + + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + // When n fringe cases are encountered + if (n0 < NR) + { + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization after packing. Any n0 < NR(32) can be expressed + // as n0 = 16 + n`. + dim_t n0_rem = n0 % 16; + dim_t n0_16 = n0 / 16; + dim_t k0_updated = k0; + + // Making multiple of 2 to suit k in vpmaddubsw + k0_updated += (k0_updated & 0x1); + + if (n0_16 == 1) + { + lpgemm_rowvar_u8s8s16o16_6x16( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ((rs_b / 2) * 1), cs_b, + c, rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale); + + b = b + (16 * k0_updated); + c = c + 16; + post_op_c_j += 16; + } + + if (n0_rem > 0) + { + lpgemm_rowvar_u8s8s16o16_6xlt16( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ((rs_b / 2) * 1), cs_b, + c, rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale); + } + + // If fringe cases are encountered, return early + return; + } + + for (dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR) + { + + _mm256_zeroupper(); + + // Registers to use for accumulating C. + __m256i c_int16_0p0 = _mm256_setzero_si256(); + __m256i c_int16_0p1 = _mm256_setzero_si256(); + + __m256i c_int16_1p0 = _mm256_setzero_si256(); + __m256i c_int16_1p1 = _mm256_setzero_si256(); + + __m256i c_int16_2p0 = _mm256_setzero_si256(); + __m256i c_int16_2p1 = _mm256_setzero_si256(); + + __m256i c_int16_3p0 = _mm256_setzero_si256(); + __m256i c_int16_3p1 = _mm256_setzero_si256(); + + __m256i c_int16_4p0 = _mm256_setzero_si256(); + __m256i c_int16_4p1 = _mm256_setzero_si256(); + + __m256i c_int16_5p0 = _mm256_setzero_si256(); + __m256i c_int16_5p1 = _mm256_setzero_si256(); + + for (dim_t kr = 0; kr < k_full_pieces; kr += 1) + { + dim_t offset = kr * 2; + + // Broadcast a[0,kr:kr+2]. 
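+ // a[0,kr] and a[0,kr+1] are adjacent uint8_t values; reading them as a
+ // single uint16_t and broadcasting fills every 16-bit lane with the same
+ // k-pair, which vpmaddubsw (_mm256_maddubs_epi16) below multiplies
+ // against the two interleaved rows of the packed B panel and sums into
+ // int16 lanes.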
+ __m256i a_int32_0 = + _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + + (cs_a * offset))); + + __m256i b0 = + _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 0))); + __m256i b1 = + _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 1))); + + // Seperate register for intermediate op + __m256i inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + c_int16_0p1 = _mm256_add_epi16(inter_vec, c_int16_0p1); + + // Broadcast a[1,kr:kr+2]. + a_int32_0 = + _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + c_int16_1p1 = _mm256_add_epi16(inter_vec, c_int16_1p1); + + // Broadcast a[2,kr:kr+2]. + a_int32_0 = + _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 2) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + c_int16_2p1 = _mm256_add_epi16(inter_vec, c_int16_2p1); + + // Broadcast a[3,kr:kr+2]. + a_int32_0 = + _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 3) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + c_int16_3p1 = _mm256_add_epi16(inter_vec, c_int16_3p1); + + // Broadcast a[4,kr:kr+2]. + a_int32_0 = + _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 4) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+4,0-31] + c_int16_4p0 = _mm256_add_epi16(inter_vec, c_int16_4p0); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + + c_int16_4p1 = _mm256_add_epi16(inter_vec, c_int16_4p1); + + // Broadcast a[5,kr:kr+2]. + a_int32_0 = + _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 5) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+4,0-31] + c_int16_5p0 = _mm256_add_epi16(inter_vec, c_int16_5p0); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + c_int16_5p1 = _mm256_add_epi16(inter_vec, c_int16_5p1); + } + + // Handle k remainder. + if (k_partial_pieces > 0) + { + + __m256i b0 = _mm256_loadu_si256((__m256i const *) + (b + (64 * k_full_pieces) + (NR * 0))); + __m256i b1 = _mm256_loadu_si256((__m256i const *) + (b + (64 * k_full_pieces) + (NR * 1))); + + uint8_t a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + __m256i a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + __m256i inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. 
+ // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + c_int16_0p1 = _mm256_add_epi16(inter_vec, c_int16_0p1); + + a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+4,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + c_int16_1p1 = _mm256_add_epi16(inter_vec, c_int16_1p1); + + a_kfringe = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + + c_int16_2p1 = _mm256_add_epi16(inter_vec, c_int16_2p1); + + a_kfringe = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + c_int16_3p1 = _mm256_add_epi16(inter_vec, c_int16_3p1); + + a_kfringe = *(a + (rs_a * 4) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_4p0 = _mm256_add_epi16(inter_vec, c_int16_4p0); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + c_int16_4p1 = _mm256_add_epi16(inter_vec, c_int16_4p1); + + a_kfringe = *(a + (rs_a * 5) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_5p0 = _mm256_add_epi16(inter_vec, c_int16_5p0); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + c_int16_5p1 = _mm256_add_epi16(inter_vec, c_int16_5p1); + } + + // Load alpha and beta + __m256i alphav = _mm256_set1_epi16(alpha); + __m256i betav = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(alphav, c_int16_0p0); + c_int16_0p1 = _mm256_mullo_epi16(alphav, c_int16_0p1); + + c_int16_1p0 = _mm256_mullo_epi16(alphav, c_int16_1p0); + c_int16_1p1 = _mm256_mullo_epi16(alphav, c_int16_1p1); + + c_int16_2p0 = _mm256_mullo_epi16(alphav, c_int16_2p0); + c_int16_2p1 = _mm256_mullo_epi16(alphav, c_int16_2p1); + + c_int16_3p0 = _mm256_mullo_epi16(alphav, c_int16_3p0); + c_int16_3p1 = _mm256_mullo_epi16(alphav, c_int16_3p1); + + c_int16_4p0 = _mm256_mullo_epi16(alphav, c_int16_4p0); + c_int16_4p1 = _mm256_mullo_epi16(alphav, c_int16_4p1); + + c_int16_5p0 = _mm256_mullo_epi16(alphav, c_int16_5p0); + c_int16_5p1 = _mm256_mullo_epi16(alphav, c_int16_5p1); + + // Scale C by beta. 
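+ // When beta is zero the existing C tile is never read; otherwise each
+ // 16-element chunk of C is loaded, multiplied by beta (low 16 bits kept
+ // by vpmullw), and added onto the alpha-scaled accumulators.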
+ if (beta != 0) + { + // c[0,0-15] + __m256i selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 0)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[0, 16-31] + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 0)) + (1 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); + c_int16_0p1 = _mm256_add_epi16(selector1, c_int16_0p1); + + // c[1,0-15] + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 1)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); + c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); + + // c[1,16-31] + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 1)) + (1 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); + c_int16_1p1 = _mm256_add_epi16(selector1, c_int16_1p1); + + // c[2,0-15] + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 2)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); + c_int16_2p0 = _mm256_add_epi16(selector1, c_int16_2p0); + + // c[2,16-31] + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 2)) + (1 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); + c_int16_2p1 = _mm256_add_epi16(selector1, c_int16_2p1); + + // c[3,0-15] + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 3)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); + c_int16_3p0 = _mm256_add_epi16(selector1, c_int16_3p0); + + // c[3,16-31] + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 3)) + (1 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); + c_int16_3p1 = _mm256_add_epi16(selector1, c_int16_3p1); + + // c[4,0-15] + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 4)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); + c_int16_4p0 = _mm256_add_epi16(selector1, c_int16_4p0); + + // c[4,16-31] + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 4)) + (1 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); + c_int16_4p1 = _mm256_add_epi16(selector1, c_int16_4p1); + + // c[5,0-15] + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 5)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); + c_int16_5p0 = _mm256_add_epi16(selector1, c_int16_5p0); + + // c[5,16-31] + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 5)) + (1 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); + c_int16_5p1 = _mm256_add_epi16(selector1, c_int16_5p1); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x32: + { + __m256i selector1 = + _mm256_loadu_si256( (__m256i const *)( + (int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 )) ); + __m256i selector2 = + _mm256_loadu_si256( (__m256i const *)( + (int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 )) ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[0, 16-31] + c_int16_0p1 = _mm256_add_epi16( selector2, c_int16_0p1 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_add_epi16( selector1, c_int16_1p0 ); + + // c[1, 16-31] + c_int16_1p1 = _mm256_add_epi16( selector2, c_int16_1p1 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_add_epi16( selector1, c_int16_2p0 ); + + // c[2, 16-31] + c_int16_2p1 = _mm256_add_epi16( selector2, c_int16_2p1 ); + + // c[3,0-15] + 
c_int16_3p0 = _mm256_add_epi16( selector1, c_int16_3p0 ); + + // c[3, 16-31] + c_int16_3p1 = _mm256_add_epi16( selector2, c_int16_3p1 ); + + // c[4,0-15] + c_int16_4p0 = _mm256_add_epi16( selector1, c_int16_4p0 ); + + // c[4, 16-31] + c_int16_4p1 = _mm256_add_epi16( selector2, c_int16_4p1 ); + + // c[5,0-15] + c_int16_5p0 = _mm256_add_epi16( selector1, c_int16_5p0 ); + + // c[5, 16-31] + c_int16_5p1 = _mm256_add_epi16( selector2, c_int16_5p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x32: + { + __m256i selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + // c[0, 16-31] + c_int16_0p1 = _mm256_max_epi16( selector1, c_int16_0p1 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_max_epi16( selector1, c_int16_1p0 ); + + // c[1,16-31] + c_int16_1p1 = _mm256_max_epi16( selector1, c_int16_1p1 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_max_epi16( selector1, c_int16_2p0 ); + + // c[2,16-31] + c_int16_2p1 = _mm256_max_epi16( selector1, c_int16_2p1 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_max_epi16( selector1, c_int16_3p0 ); + + // c[3,16-31] + c_int16_3p1 = _mm256_max_epi16( selector1, c_int16_3p1 ); + + // c[4,0-15] + c_int16_4p0 = _mm256_max_epi16( selector1, c_int16_4p0 ); + + // c[4,16-31] + c_int16_4p1 = _mm256_max_epi16( selector1, c_int16_4p1 ); + + // c[5,0-15] + c_int16_5p0 = _mm256_max_epi16( selector1, c_int16_5p0 ); + + // c[5,16-31] + c_int16_5p1 = _mm256_max_epi16( selector1, c_int16_5p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x32: + { + __m256i selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + __m256i selector1, b0; + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + // c[0,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_0p1) + + // c[1,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_1p0) + + // c[1,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_1p1) + + // c[2,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_2p0) + + // c[2,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_2p1) + + // c[3,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_3p0) + + // c[3,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_3p1) + + // c[4,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_4p0) + + // c[4,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_4p1) + + // c[5,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_5p0) + + // c[5,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_5p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_6x32: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; + + /* Load the scale vector values into the register*/ + scale_1 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + scale_2 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + + BLI_MM256_S16_DOWNSCALE(c_int16_0p0, c_int16_0p1, 0); + + BLI_MM256_S16_DOWNSCALE(c_int16_1p0, c_int16_1p1, 1); + + BLI_MM256_S16_DOWNSCALE(c_int16_2p0, c_int16_2p1, 2); + + BLI_MM256_S16_DOWNSCALE(c_int16_3p0, c_int16_3p1, 3); + + BLI_MM256_S16_DOWNSCALE(c_int16_4p0, c_int16_4p1, 4); + + BLI_MM256_S16_DOWNSCALE(c_int16_5p0, c_int16_5p1, 5); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x32_DISABLE: + ; + + // Store the results. 
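+ // Each of the 6 rows is written back with two unaligned 256-bit stores,
+ // covering columns 0-15 and 16-31 of the C tile.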
+ // c[0,0-15] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 0 ) ) + ( 0*16 )), c_int16_0p0 ); + + // c[0, 16-31] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 0 ) ) + ( 1*16 )), c_int16_0p1 ); + + // c[1,0-15] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 1 ) ) + ( 0*16 )), c_int16_1p0 ); + + // c[1,16-31] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 1 ) ) + ( 1*16 )), c_int16_1p1 ); + + // c[2,0-15] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 2 ) ) + ( 0*16 )), c_int16_2p0 ); + + // c[2,16-31] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 2 ) ) + ( 1*16 )), c_int16_2p1 ); + + // c[3,0-15] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 3 ) ) + ( 0*16 )), c_int16_3p0 ); + + // c[3,16-31] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 3 ) ) + ( 1*16 )), c_int16_3p1 ); + + // c[4,0-15] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 4 ) ) + ( 0*16 )), c_int16_4p0 ); + + // c[4,16-31] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 4 ) ) + ( 1*16 )), c_int16_4p1 ); + + // c[5,0-15] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 5 ) ) + ( 0*16 )), c_int16_5p0 ); + + // c[5,16-31] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 5 ) ) + ( 1*16 )), c_int16_5p1 ); + + a = a + ( MR * ps_a ); + post_op_c_i += MR; + } + + if (m_partial_pieces > 0) + { + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization after packing. Any m0 < MR(6) can be expressed + // as a combination of numbers from the set {4, 2, 1}. + dim_t m_partial4 = m_partial_pieces / 4; + m_partial_pieces = m_partial_pieces % 4; + + dim_t m_partial2 = m_partial_pieces / 2; + dim_t m_partial = m_partial_pieces % 2; + + if (m_partial4 == 1) + { + lpgemm_rowvar_u8s8s16o16_4x32( + k0, + a, rs_a, cs_a, + b, rs_b, cs_b, + (c + (rs_c * m_full_pieces_loop_limit)), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale); + + // a pointer increment + a = a + (4 * ps_a); + m_full_pieces_loop_limit += 4; + post_op_c_i += 4; + } + + if (m_partial2 == 1) + { + lpgemm_rowvar_u8s8s16o16_2x32( + k0, + a, rs_a, cs_a, + b, rs_b, cs_b, + (c + (rs_c * m_full_pieces_loop_limit)), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale); + + // a pointer increment + a = a + (2 * ps_a); + m_full_pieces_loop_limit += 2; + post_op_c_i += 2; + } + + if (m_partial == 1) + { + lpgemm_rowvar_u8s8s16o16_1x32( + k0, + a, rs_a, cs_a, + b, rs_b, cs_b, + (c + (rs_c * m_full_pieces_loop_limit)), rs_c, + alpha, beta,is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale); + post_op_c_i += 1; + } + } +} diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c new file mode 100644 index 0000000000..4934b8b11c --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c @@ -0,0 +1,820 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+ - Neither the name(s) of the copyright holder(s) nor the names of its
+ contributors may be used to endorse or promote products derived
+ from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include <immintrin.h>
+
+#include "blis.h"
+#include "lpgemm_kernels.h"
+#include "lpgemm_s16_kern_macros.h"
+
+// 4x32 int8o16 kernel
+LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32)
+{
+ dim_t NR = 32;
+
+ static void *post_ops_labels[] =
+ {
+ &&POST_OPS_4x32_DISABLE,
+ &&POST_OPS_BIAS_4x32,
+ &&POST_OPS_RELU_4x32,
+ &&POST_OPS_RELU_SCALE_4x32,
+ &&POST_OPS_DOWNSCALE_4x32
+ };
+
+ // k is processed in pairs to suit the vpmaddubsw instruction
+ dim_t k_full_pieces = k0 / 2;
+ dim_t k_partial_pieces = k0 % 2;
+
+ // B matrix storage.
+ __m256i b0;
+ __m256i b1;
+
+ // A matrix storage.
+ __m256i a_int32_0;
+ __m256i a_int32_1;
+ __m256i inter_vec[4];
+
+ // Registers to use for accumulating C.
+ __m256i c_int16_0p0 = _mm256_setzero_si256();
+ __m256i c_int16_0p1 = _mm256_setzero_si256();
+
+ __m256i c_int16_1p0 = _mm256_setzero_si256();
+ __m256i c_int16_1p1 = _mm256_setzero_si256();
+
+ __m256i c_int16_2p0 = _mm256_setzero_si256();
+ __m256i c_int16_2p1 = _mm256_setzero_si256();
+
+ __m256i c_int16_3p0 = _mm256_setzero_si256();
+ __m256i c_int16_3p1 = _mm256_setzero_si256();
+
+ for (dim_t kr = 0; kr < k_full_pieces; kr += 1)
+ {
+ dim_t offset = kr * 2;
+
+ // Broadcast a[0,kr:kr+2].
+ a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset)));
+
+ b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 0)));
+ b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 1)));
+
+ // Broadcast a[1,kr:kr+2].
+ a_int32_1 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset)));
+
+ // Separate register for intermediate op
+ inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0);
+ inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1);
+
+ // Perform column direction mat-mul with k = 2.
+ // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31]
+ c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0);
+ c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1);
+
+ // Broadcast a[2,kr:kr+2].
+ a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 2) + (cs_a * offset)));
+
+ // Separate register for intermediate op
+ inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0);
+ inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1);
+
+ // Perform column direction mat-mul with k = 2.
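+ // Rows are processed in pairs: a_int32_0/a_int32_1 hold the k-pair
+ // broadcasts for two rows at a time while inter_vec[0..3] hold their
+ // vpmaddubsw products, interleaving each row's accumulation with the
+ // broadcast for a later row.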
+ // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31]
+ c_int16_1p0 = _mm256_add_epi16(inter_vec[2], c_int16_1p0);
+ c_int16_1p1 = _mm256_add_epi16(inter_vec[3], c_int16_1p1);
+
+ // Broadcast a[3,kr:kr+2].
+ a_int32_1 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 3) + (cs_a * offset)));
+
+ // Separate register for intermediate op
+ inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0);
+ inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1);
+
+ // Perform column direction mat-mul with k = 2.
+ // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31]
+ c_int16_2p0 = _mm256_add_epi16(inter_vec[0], c_int16_2p0);
+ c_int16_2p1 = _mm256_add_epi16(inter_vec[1], c_int16_2p1);
+
+ // Separate register for intermediate op
+ inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0);
+ inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1);
+
+ // Perform column direction mat-mul with k = 2.
+ // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31]
+ c_int16_3p0 = _mm256_add_epi16(inter_vec[2], c_int16_3p0);
+ c_int16_3p1 = _mm256_add_epi16(inter_vec[3], c_int16_3p1);
+ }
+
+ // Handle k remainder.
+ if (k_partial_pieces > 0)
+ {
+ uint8_t a_kfringe;
+
+ b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 0)));
+ b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 1)));
+
+ a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2)));
+ a_int32_0 = _mm256_set1_epi8(a_kfringe);
+
+ // Separate register for intermediate op
+ inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0);
+ inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1);
+
+ // Perform column direction mat-mul with k = 2.
+ // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31]
+ c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0);
+ c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1);
+
+ a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2)));
+ a_int32_1 = _mm256_set1_epi8(a_kfringe);
+
+ // Separate register for intermediate op
+ inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0);
+ inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1);
+
+ // Perform column direction mat-mul with k = 2.
+ // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31]
+ c_int16_1p0 = _mm256_add_epi16(inter_vec[2], c_int16_1p0);
+ c_int16_1p1 = _mm256_add_epi16(inter_vec[3], c_int16_1p1);
+
+ a_kfringe = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2)));
+ a_int32_0 = _mm256_set1_epi8(a_kfringe);
+
+ // Separate register for intermediate op
+ inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0);
+ inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1);
+
+ // Perform column direction mat-mul with k = 2.
+ // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31]
+ c_int16_2p0 = _mm256_add_epi16(inter_vec[0], c_int16_2p0);
+ c_int16_2p1 = _mm256_add_epi16(inter_vec[1], c_int16_2p1);
+
+ a_kfringe = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2)));
+ a_int32_1 = _mm256_set1_epi8(a_kfringe);
+
+ // Separate register for intermediate op
+ inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0);
+ inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1);
+
+ // Perform column direction mat-mul with k = 2.
+ // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec[2], c_int16_3p0); + c_int16_3p1 = _mm256_add_epi16(inter_vec[3], c_int16_3p1); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + c_int16_0p1 = _mm256_mullo_epi16(selector1, c_int16_0p1); + + c_int16_1p0 = _mm256_mullo_epi16(selector1, c_int16_1p0); + c_int16_1p1 = _mm256_mullo_epi16(selector1, c_int16_1p1); + + c_int16_2p0 = _mm256_mullo_epi16(selector1, c_int16_2p0); + c_int16_2p1 = _mm256_mullo_epi16(selector1, c_int16_2p1); + + c_int16_3p0 = _mm256_mullo_epi16(selector1, c_int16_3p0); + c_int16_3p1 = _mm256_mullo_epi16(selector1, c_int16_3p1); + + // Scale C by beta. + if (beta != 0) + { + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 0) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[0, 16-31] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 0) + (1 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p1 = _mm256_add_epi16(selector1, c_int16_0p1); + + // c[1,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 1) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); + + // c[1,16-31] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 1) + (1 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p1 = _mm256_add_epi16(selector1, c_int16_1p1); + + // c[2,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 2) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_2p0 = _mm256_add_epi16(selector1, c_int16_2p0); + + // c[2,16-31] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 2) + (1 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_2p1 = _mm256_add_epi16(selector1, c_int16_2p1); + + // c[3,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 3) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_3p0 = _mm256_add_epi16(selector1, c_int16_3p0); + + // c[3,16-31] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 3) + (1 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_3p1 = _mm256_add_epi16(selector1, c_int16_3p1); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x32: + { + selector1 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 )) ); + selector2 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 )) ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + // c[0, 16-31] + c_int16_0p1 = _mm256_add_epi16( selector2, c_int16_0p1 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_add_epi16( selector1, c_int16_1p0 ); + + // c[1, 16-31] + c_int16_1p1 = _mm256_add_epi16( selector2, c_int16_1p1 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_add_epi16( selector1, c_int16_2p0 ); + + // c[2, 16-31] + c_int16_2p1 = _mm256_add_epi16( selector2, c_int16_2p1 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_add_epi16( selector1, c_int16_3p0 ); + + // c[3, 16-31] + c_int16_3p1 = _mm256_add_epi16( selector2, c_int16_3p1 ); + + 
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x32: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + // c[0, 16-31] + c_int16_0p1 = _mm256_max_epi16( selector1, c_int16_0p1 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_max_epi16( selector1, c_int16_1p0 ); + + // c[1,16-31] + c_int16_1p1 = _mm256_max_epi16( selector1, c_int16_1p1 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_max_epi16( selector1, c_int16_2p0 ); + + // c[2,16-31] + c_int16_2p1 = _mm256_max_epi16( selector1, c_int16_2p1 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_max_epi16( selector1, c_int16_3p0 ); + + // c[3,16-31] + c_int16_3p1 = _mm256_max_epi16( selector1, c_int16_3p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x32: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + // c[0,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_0p1) + + // c[1,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_1p0) + + // c[1,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_1p1) + + // c[2,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_2p0) + + // c[2,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_2p1) + + // c[3,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_3p0) + + // c[3,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_3p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x32: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; + + /* Load the scale vector values into the register*/ + scale_1 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + scale_2 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + + BLI_MM256_S16_DOWNSCALE(c_int16_0p0, c_int16_0p1, 0); + + BLI_MM256_S16_DOWNSCALE(c_int16_1p0, c_int16_1p1, 1); + + BLI_MM256_S16_DOWNSCALE(c_int16_2p0, c_int16_2p1, 2); + + BLI_MM256_S16_DOWNSCALE(c_int16_3p0, c_int16_3p1, 3); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x32_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 0 ) + ( 0*16 )), c_int16_0p0 ); + + // c[0, 16-31] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 0 ) + ( 1*16 )), c_int16_0p1 ); + + // c[1,0-15] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 1 ) + ( 0*16 )), c_int16_1p0 ); + + // c[1,16-31] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 1 ) + ( 1*16 )), c_int16_1p1 ); + + // c[2,0-15] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 2 ) + ( 0*16 )), c_int16_2p0 ); + + // c[2,16-31] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 2 ) + ( 1*16 )), c_int16_2p1 ); + + // c[3,0-15] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 3 ) + ( 0*16 )), c_int16_3p0 ); + + // c[3,16-31] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 3 ) + ( 1*16 )), c_int16_3p1 ); +} + + +// 2x32 int8o16 kernel +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) +{ + dim_t NR = 32; + + static void *post_ops_labels[] = + { + &&POST_OPS_2x32_DISABLE, + &&POST_OPS_BIAS_2x32, + &&POST_OPS_RELU_2x32, + &&POST_OPS_RELU_SCALE_2x32, + &&POST_OPS_DOWNSCALE_2x32 + }; + + // The division is done by considering the vpmaddubsw instruction + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + // B matrix storage. + __m256i b0; + __m256i b1; + + // A matrix storage. 
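+ // Two broadcast registers plus four intermediate vectors keep the two
+ // rows' multiply-adds independent of each other.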
+ __m256i a_int32_0; + __m256i a_int32_1; + __m256i inter_vec[4]; + + // Registers to use for accumulating C. + __m256i c_int16_0p0 = _mm256_setzero_si256(); + __m256i c_int16_0p1 = _mm256_setzero_si256(); + + __m256i c_int16_1p0 = _mm256_setzero_si256(); + __m256i c_int16_1p1 = _mm256_setzero_si256(); + + for (dim_t kr = 0; kr < k_full_pieces; kr += 1) + { + dim_t offset = kr * 2; + + // Broadcast a[0,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); + + b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 0))); + b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 1))); + + // Broadcast a[1,kr:kr+2]. + a_int32_1 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); + inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0); + c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1); + + // Seperate register for intermediate op + inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); + inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec[2], c_int16_1p0); + c_int16_1p1 = _mm256_add_epi16(inter_vec[3], c_int16_1p1); + } + // Handle k remainder. + if (k_partial_pieces > 0) + { + uint8_t a_kfringe; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 0))); + b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 1))); + + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); + inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0); + c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1); + + a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_1 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); + inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec[2], c_int16_1p0); + c_int16_1p1 = _mm256_add_epi16(inter_vec[3], c_int16_1p1); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + c_int16_0p1 = _mm256_mullo_epi16(selector1, c_int16_0p1); + + c_int16_1p0 = _mm256_mullo_epi16(selector1, c_int16_1p0); + c_int16_1p1 = _mm256_mullo_epi16(selector1, c_int16_1p1); + + // Scale C by beta. 
+ if (beta != 0) + { + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 0) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[0, 16-31] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 0) + (1 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p1 = _mm256_add_epi16(selector1, c_int16_0p1); + + // c[1,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 1) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); + + // c[1,16-31] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 1) + (1 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p1 = _mm256_add_epi16(selector1, c_int16_1p1); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x32: + { + selector1 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 )) ); + selector2 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 )) ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + // c[0, 16-31] + c_int16_0p1 = _mm256_add_epi16( selector2, c_int16_0p1 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_add_epi16( selector1, c_int16_1p0 ); + + // c[1, 16-31] + c_int16_1p1 = _mm256_add_epi16( selector2, c_int16_1p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x32: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + // c[0, 16-31] + c_int16_0p1 = _mm256_max_epi16( selector1, c_int16_0p1 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_max_epi16( selector1, c_int16_1p0 ); + + // c[1,16-31] + c_int16_1p1 = _mm256_max_epi16( selector1, c_int16_1p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x32: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + // c[0,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_0p1) + + // c[1,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_1p0) + + // c[1,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_1p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2x32: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; + + /* Load the scale vector values into the register*/ + scale_1 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + scale_2 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + + BLI_MM256_S16_DOWNSCALE(c_int16_0p0, c_int16_0p1, 0); + + BLI_MM256_S16_DOWNSCALE(c_int16_1p0, c_int16_1p1, 1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x32_DISABLE: + ; + + // Store the results. 
+ // c[0,0-15] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 0 ) + ( 0*16 )), c_int16_0p0 ); + + // c[0, 16-31] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 0 ) + ( 1*16 )), c_int16_0p1 ); + + // c[1,0-15] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 1 ) + ( 0*16 )), c_int16_1p0 ); + + // c[1,16-31] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 1 ) + ( 1*16 )), c_int16_1p1 ); +} + +// 1x32 int8o16 kernel +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32) +{ + dim_t NR = 32; + + static void *post_ops_labels[] = + { + &&POST_OPS_1x32_DISABLE, + &&POST_OPS_BIAS_1x32, + &&POST_OPS_RELU_1x32, + &&POST_OPS_RELU_SCALE_1x32, + &&POST_OPS_DOWNSCALE_1x32 + }; + + // The division is done by considering the vpmaddubsw instruction + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + // B matrix storage. + __m256i b0; + __m256i b1; + + // A matrix storage. + __m256i a_int32_0; + __m256i inter_vec[2]; + + // Registers to use for accumulating C. + __m256i c_int16_0p0 = _mm256_setzero_si256(); + __m256i c_int16_0p1 = _mm256_setzero_si256(); + + for (dim_t kr = 0; kr < k_full_pieces; kr += 1) + { + dim_t offset = kr * 2; + + // Broadcast a[0,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); + + b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 0))); + b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 1))); + + // Seperate register for intermediate op + inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); + inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0); + c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1); + } + // Handle k remainder. + if (k_partial_pieces > 0) + { + uint8_t a_kfringe; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 0))); + b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 1))); + + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); + inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0); + c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + c_int16_0p1 = _mm256_mullo_epi16(selector1, c_int16_0p1); + + // Scale C by beta. 
+ if (beta != 0) + { + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 0) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[0, 16-31] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 0) + (1 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p1 = _mm256_add_epi16(selector1, c_int16_0p1); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x32: + { + selector1 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 )) ); + selector2 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 )) ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + // c[0, 16-31] + c_int16_0p1 = _mm256_add_epi16( selector2, c_int16_0p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x32: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + // c[0, 16-31] + c_int16_0p1 = _mm256_max_epi16( selector1, c_int16_0p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x32: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + // c[0,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_0p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1x32: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; + + /* Load the scale vector values into the register*/ + scale_1 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + scale_2 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + + BLI_MM256_S16_DOWNSCALE(c_int16_0p0, c_int16_0p1, 0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x32_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 0 ) + ( 0*16 )), c_int16_0p0 ); + + // c[0, 16-31] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 0 ) + ( 1*16 )), c_int16_0p1 ); +} diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c new file mode 100644 index 0000000000..f24455036e --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c @@ -0,0 +1,1266 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include <immintrin.h>
+
+#include "blis.h"
+#include "lpgemm_kernels.h"
+#include "lpgemm_s16_kern_macros.h"
+
+// 4x16 int8o16 kernel
+LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16)
+{
+ dim_t NR = 16;
+
+ static void *post_ops_labels[] =
+ {
+ &&POST_OPS_4x16_DISABLE,
+ &&POST_OPS_BIAS_4x16,
+ &&POST_OPS_RELU_4x16,
+ &&POST_OPS_RELU_SCALE_4x16,
+ &&POST_OPS_DOWNSCALE_4x16
+ };
+
+ // k is processed in pairs to suit the vpmaddubsw instruction
+ dim_t k_full_pieces = k0 / 2;
+ dim_t k_partial_pieces = k0 % 2;
+
+ // B matrix storage.
+ __m256i b0;
+
+ // A matrix storage.
+ __m256i a_int32_0;
+ __m256i inter_vec;
+
+ // Registers to use for accumulating C.
+ __m256i c_int16_0p0 = _mm256_setzero_si256();
+ __m256i c_int16_1p0 = _mm256_setzero_si256();
+ __m256i c_int16_2p0 = _mm256_setzero_si256();
+ __m256i c_int16_3p0 = _mm256_setzero_si256();
+
+ for (dim_t kr = 0; kr < k_full_pieces; kr += 1)
+ {
+ dim_t offset = kr * 2;
+
+ b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * kr) + (NR * 0)));
+
+ // Broadcast a[0,kr:kr+2].
+ a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset)));
+
+ // Separate register for intermediate op
+ inter_vec = _mm256_maddubs_epi16(a_int32_0, b0);
+
+ // Perform column direction mat-mul with k = 2.
+ // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15]
+ c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0);
+
+ // Broadcast a[1,kr:kr+2].
+ a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset)));
+
+ // Separate register for intermediate op
+ inter_vec = _mm256_maddubs_epi16(a_int32_0, b0);
+
+ // Perform column direction mat-mul with k = 2.
+ // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15]
+ c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0);
+
+ // Broadcast a[2,kr:kr+2].
+ a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 2) + (cs_a * offset)));
+
+ // Separate register for intermediate op
+ inter_vec = _mm256_maddubs_epi16(a_int32_0, b0);
+
+ // Perform column direction mat-mul with k = 2.
+ // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15]
+ c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0);
+
+ // Broadcast a[3,kr:kr+2].
+ a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 3) + (cs_a * offset)));
+
+ // Separate register for intermediate op
+ inter_vec = _mm256_maddubs_epi16(a_int32_0, b0);
+
+ // Perform column direction mat-mul with k = 2.
+ // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15]
+ c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0);
+ }
+
+ // Handle k remainder.
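+ // When k0 is odd the final A element of each row has no partner, so it
+ // is broadcast to every byte with _mm256_set1_epi8 (rather than per
+ // 16-bit pair) and multiplied against the last interleaved row pair of
+ // the packed B buffer, whose second row is padding.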
+ if (k_partial_pieces > 0) + { + uint8_t a_kfringe; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0))); + + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + + a_kfringe = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); + + a_kfringe = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + + c_int16_1p0 = _mm256_mullo_epi16(selector1, c_int16_1p0); + + c_int16_2p0 = _mm256_mullo_epi16(selector1, c_int16_2p0); + + c_int16_3p0 = _mm256_mullo_epi16(selector1, c_int16_3p0); + + // Scale C by beta. 
+ if (beta != 0) + { + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 0) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[1,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 1) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); + + // c[2,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 2) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_2p0 = _mm256_add_epi16(selector1, c_int16_2p0); + + // c[3,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 3) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_3p0 = _mm256_add_epi16(selector1, c_int16_3p0); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x16: + { + selector1 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 )) ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_add_epi16( selector1, c_int16_1p0 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_add_epi16( selector1, c_int16_2p0 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_add_epi16( selector1, c_int16_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x16: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_max_epi16( selector1, c_int16_1p0 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_max_epi16( selector1, c_int16_2p0 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_max_epi16( selector1, c_int16_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x16: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + // c[1,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_1p0) + + // c[2,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_2p0) + + // c[3,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_3p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x16: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; + + /* Load the scale vector values into the register*/ + scale_1 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + scale_2 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + + BLI_MM256_S16_DOWNSCALE2(c_int16_0p0, c_int16_1p0, 0, 1); + + BLI_MM256_S16_DOWNSCALE2(c_int16_2p0, c_int16_3p0, 2, 3); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x16_DISABLE: + ; + + // Store the results. 
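+	// The post-op handlers above are chained with computed gotos:
+	// post_ops_labels[] holds the label addresses, POST_OP_LABEL_LASTK_SAFE_JUMP
+	// (as its name suggests) only enters the chain on the last k block, and each
+	// handler jumps to the label of the next post-op in post_ops_list via
+	// POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR until the *_DISABLE label is
+	// reached. Each row of the full 16-column tile is 16 int16 values (32 bytes),
+	// i.e. exactly one __m256i, so the results are written back below with a
+	// single unaligned 256-bit store per row.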
+ // c[0,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), c_int16_0p0); + + // c[1,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 1) + (0 * 16)), c_int16_1p0); + + // c[2,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 2) + (0 * 16)), c_int16_2p0); + + // c[3,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 3) + (0 * 16)), c_int16_3p0); +} + +// 4x16 int8o16 kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) +{ + dim_t NR = 16; + + static void *post_ops_labels[] = + { + &&POST_OPS_4xlt16_DISABLE, + &&POST_OPS_BIAS_4xlt16, + &&POST_OPS_RELU_4xlt16, + &&POST_OPS_RELU_SCALE_4xlt16, + &&POST_OPS_DOWNSCALE_4xlt16 + }; + + // The division is done by considering the vpmaddubsw instruction + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + // B matrix storage. + __m256i b0; + + // A matrix storage. + __m256i a_int32_0; + __m256i inter_vec; + + int16_t buf0[16]; + int16_t buf1[16]; + int16_t buf2[16]; + int16_t buf3[16]; + + // Registers to use for accumulating C. + __m256i c_int16_0p0 = _mm256_setzero_si256(); + + __m256i c_int16_1p0 = _mm256_setzero_si256(); + + __m256i c_int16_2p0 = _mm256_setzero_si256(); + + __m256i c_int16_3p0 = _mm256_setzero_si256(); + + for (dim_t kr = 0; kr < k_full_pieces; kr += 1) + { + dim_t offset = kr * 2; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * kr) + (NR * 0))); + + // Broadcast a[0,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + // Broadcast a[1,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + + // Broadcast a[2,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 2) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); + + // Broadcast a[3,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 3) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0); + } + + // Handle k remainder. + if (k_partial_pieces > 0) + { + uint8_t a_kfringe; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0))); + + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. 
+ // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + + a_kfringe = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); + + a_kfringe = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + + c_int16_1p0 = _mm256_mullo_epi16(selector1, c_int16_1p0); + + c_int16_2p0 = _mm256_mullo_epi16(selector1, c_int16_2p0); + + c_int16_3p0 = _mm256_mullo_epi16(selector1, c_int16_3p0); + + // Scale C by beta. + if (beta != 0) + { + memcpy(buf0, (c + (rs_c * 0)), (n0_rem * sizeof(int16_t))); + memcpy(buf1, (c + (rs_c * 1)), (n0_rem * sizeof(int16_t))); + memcpy(buf2, (c + (rs_c * 2)), (n0_rem * sizeof(int16_t))); + memcpy(buf3, (c + (rs_c * 3)), (n0_rem * sizeof(int16_t))); + + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf0); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[1,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf1); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); + + // c[2,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf2); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_2p0 = _mm256_add_epi16(selector1, c_int16_2p0); + + // c[3,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf3); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_3p0 = _mm256_add_epi16(selector1, c_int16_3p0); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4xlt16: + { + memcpy( buf0, ( ( int16_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ), ( n0_rem * sizeof( int16_t ) ) ); + + selector1 = + _mm256_loadu_si256( (__m256i const *) buf0 ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_add_epi16( selector1, c_int16_1p0 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_add_epi16( selector1, c_int16_2p0 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_add_epi16( selector1, c_int16_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4xlt16: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_max_epi16( selector1, c_int16_1p0 ); + + // c[2,0-15] + 
c_int16_2p0 = _mm256_max_epi16( selector1, c_int16_2p0 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_max_epi16( selector1, c_int16_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4xlt16: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + // c[1,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_1p0) + + // c[2,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_2p0) + + // c[3,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_3p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4xlt16: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; + + float float_buf[16]; + int8_t store_buf[16]; + + memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + + // Load the scale vector values into the register + scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); + scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + + BLI_MM256_S16_DOWNSCALE2_LT16(c_int16_0p0, c_int16_1p0, 0, 1) + + BLI_MM256_S16_DOWNSCALE2_LT16(c_int16_2p0, c_int16_3p0, 2, 3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4xlt16_DISABLE: + ; + + // c[0,0-15] + _mm256_storeu_si256((__m256i_u *)buf0, c_int16_0p0); + + // c[1,0-15] + _mm256_storeu_si256((__m256i_u *)buf1, c_int16_1p0); + + // c[2,0-15] + _mm256_storeu_si256((__m256i_u *)buf2, c_int16_2p0); + + // c[3,0-15] + _mm256_storeu_si256((__m256i_u *)buf3, c_int16_3p0); + + memcpy(c + (rs_c * 0) + (0 * 16), buf0, (n0_rem * sizeof(int16_t))); + + // c[1,0-15] + memcpy(c + (rs_c * +1) + (0 * 16), buf1, (n0_rem * sizeof(int16_t))); + + // c[2,0-15] + memcpy(c + (rs_c * +2) + (0 * 16), buf2, (n0_rem * sizeof(int16_t))); + + // c[3,0-15] + memcpy(c + (rs_c * +3) + (0 * 16), buf3, (n0_rem * sizeof(int16_t))); +} + +// 2x16 int8o16 kernel +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16) +{ + dim_t NR = 16; + + static void *post_ops_labels[] = + { + &&POST_OPS_2x16_DISABLE, + &&POST_OPS_BIAS_2x16, + &&POST_OPS_RELU_2x16, + &&POST_OPS_RELU_SCALE_2x16, + &&POST_OPS_DOWNSCALE_2x16 + }; + + // The division is done by considering the vpmaddubsw instruction + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + // B matrix storage. + __m256i b0; + + // A matrix storage. + __m256i a_int32_0; + __m256i inter_vec; + + // Registers to use for accumulating C. + __m256i c_int16_0p0 = _mm256_setzero_si256(); + + __m256i c_int16_1p0 = _mm256_setzero_si256(); + + for (dim_t kr = 0; kr < k_full_pieces; kr += 1) + { + dim_t offset = kr * 2; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * kr) + (NR * 0))); + + // Broadcast a[0,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + // Broadcast a[1,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + } + // Handle k remainder. 
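+	// As in the 4x16 kernel above, an odd k0 leaves one unpaired column of A;
+	// it is handled below by broadcasting that single byte against the
+	// zero-padded last row of the packed B block.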
+ if (k_partial_pieces > 0) + { + uint8_t a_kfringe; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0))); + + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + + c_int16_1p0 = _mm256_mullo_epi16(selector1, c_int16_1p0); + + // Scale C by beta. + if (beta != 0) + { + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 0) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[1,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 1) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x16: + { + selector1 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 )) ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_add_epi16( selector1, c_int16_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x16: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_max_epi16( selector1, c_int16_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x16: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + // c[1,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_1p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2x16: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; + + /* Load the scale vector values into the register*/ + scale_1 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + scale_2 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + + BLI_MM256_S16_DOWNSCALE2(c_int16_0p0, c_int16_1p0, 0, 1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x16_DISABLE: + ; + + // Store the results. 
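+	// The DOWNSCALE handler above rescales the int16 accumulators with the
+	// per-column float factors from post_ops_list_temp->scale_factor inside
+	// the BLI_MM256_S16_DOWNSCALE2 macro; judging from the int8_t store
+	// buffers in the *lt16 variants and the rs_c_downscale argument, the
+	// rescaled values are emitted as 8-bit output. When no post-op remains,
+	// control reaches POST_OPS_2x16_DISABLE and the plain int16 tile is
+	// stored below.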
+ // c[0,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), c_int16_0p0); + + // c[1,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 1) + (0 * 16)), c_int16_1p0); +} + +// 2xlt16 int8o16 kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) +{ + dim_t NR = 16; + + static void *post_ops_labels[] = + { + &&POST_OPS_2xlt16_DISABLE, + &&POST_OPS_BIAS_2xlt16, + &&POST_OPS_RELU_2xlt16, + &&POST_OPS_RELU_SCALE_2xlt16, + &&POST_OPS_DOWNSCALE_2xlt16 + }; + + // The division is done by considering the vpmaddubsw instruction + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + // B matrix storage. + __m256i b0; + + // A matrix storage. + __m256i a_int32_0; + __m256i inter_vec; + + int16_t buf0[16]; + int16_t buf1[16]; + + // Registers to use for accumulating C. + __m256i c_int16_0p0 = _mm256_setzero_si256(); + + __m256i c_int16_1p0 = _mm256_setzero_si256(); + + for (dim_t kr = 0; kr < k_full_pieces; kr += 1) + { + dim_t offset = kr * 2; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * kr) + (NR * 0))); + + // Broadcast a[0,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + // Broadcast a[1,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + } + // Handle k remainder. + if (k_partial_pieces > 0) + { + uint8_t a_kfringe; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0))); + + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + + c_int16_1p0 = _mm256_mullo_epi16(selector1, c_int16_1p0); + + // Scale C by beta. 
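+	// In the *lt16 fringe kernels only n0_rem (< 16) columns of each C row
+	// are valid, so the beta path below first memcpy's n0_rem int16 elements
+	// into 16-element stack buffers and then runs the full-width vector
+	// update on the buffers; the stores at the end of the kernel go through
+	// the same buffers, so C is never read or written past its last valid
+	// column.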
+ if (beta != 0) + { + memcpy(buf0, (c + (rs_c * 0)), (n0_rem * sizeof(int16_t))); + memcpy(buf1, (c + (rs_c * 1)), (n0_rem * sizeof(int16_t))); + + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf0); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[1,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf1); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2xlt16: + { + memcpy( buf0, ( ( int16_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ), ( n0_rem * sizeof( int16_t ) ) ); + + selector1 = + _mm256_loadu_si256( (__m256i const *) buf0); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_add_epi16( selector1, c_int16_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2xlt16: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_max_epi16( selector1, c_int16_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2xlt16: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + // c[1,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_1p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2xlt16: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; + + float float_buf[16]; + int8_t store_buf[16]; + + memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + + // Load the scale vector values into the register + scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); + scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + + BLI_MM256_S16_DOWNSCALE2_LT16(c_int16_0p0, c_int16_1p0, 0, 1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2xlt16_DISABLE: + ; + + // c[0,0-15] + _mm256_storeu_si256((__m256i_u *)buf0, c_int16_0p0); + + // c[1,0-15] + _mm256_storeu_si256((__m256i_u *)buf1, c_int16_1p0); + + // c[0,0-15] + memcpy(c + (rs_c * 0) + (0 * 16), buf0, (n0_rem * sizeof(int16_t))); + + // c[1,0-15] + memcpy(c + (rs_c * 1) + (0 * 16), buf1, (n0_rem * sizeof(int16_t))); +} + +// 1x16 int8o16 kernel +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16) +{ + int NR = 16; + + static void *post_ops_labels[] = + { + &&POST_OPS_1x16_DISABLE, + &&POST_OPS_BIAS_1x16, + &&POST_OPS_RELU_1x16, + &&POST_OPS_RELU_SCALE_1x16, + &&POST_OPS_DOWNSCALE_1x16 + }; + + // The division is done by considering the vpmaddubsw instruction + int k_full_pieces = k0 / 2; + int k_partial_pieces = k0 % 2; + + // B matrix storage. + __m256i b0; + + // A matrix storage. + __m256i a_int32_0; + __m256i inter_vec; + + // Registers to use for accumulating C. + __m256i c_int16_0p0 = _mm256_setzero_si256(); + + for (int kr = 0; kr < k_full_pieces; kr += 1) + { + int offset = kr * 2; + + // Broadcast a[0,kr:kr+2]. 
+ a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * kr) + (NR * 0))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + } + // Handle k remainder. + if (k_partial_pieces > 0) + { + uint8_t a_kfringe; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0))); + + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + + // Scale C by beta. + if (beta != 0) + { + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 0) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x16: + { + selector1 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 )) ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x16: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x16: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1x16: + { + __m128i temp[2]; + __m256i temp_32[2], zero_reg; + __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; + + /* Load the scale vector values into the register*/ + scale_1 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + scale_2 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + + zero_reg = _mm256_setzero_si256(); + + BLI_MM256_S16_DOWNSCALE2_EDGE(c_int16_0p0, 0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 0 ) + ( 0*16 )), c_int16_0p0 ); +} + +// 1xlt16 int8o16 kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16) +{ + int NR = 16; + + static void *post_ops_labels[] = + { + &&POST_OPS_1xlt16_DISABLE, + &&POST_OPS_BIAS_1xlt16, + &&POST_OPS_RELU_1xlt16, + &&POST_OPS_RELU_SCALE_1xlt16, + &&POST_OPS_DOWNSCALE_1xlt16 + }; + + // The division is done by considering the vpmaddubsw instruction + int k_full_pieces = k0 / 2; + int k_partial_pieces = k0 % 2; + + // B matrix storage. + __m256i b0; + + // A matrix storage. + __m256i a_int32_0; + __m256i inter_vec; + + int16_t buf0[16]; + + // Registers to use for accumulating C. 
+ __m256i c_int16_0p0 = _mm256_setzero_si256(); + + for (int kr = 0; kr < k_full_pieces; kr += 1) + { + int offset = kr * 2; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * kr) + (NR * 0))); + + // Broadcast a[0,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + } + // Handle k remainder. + if (k_partial_pieces > 0) + { + uint8_t a_kfringe; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0))); + + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + + // Scale C by beta. + if (beta != 0) + { + memcpy(buf0, (c + (rs_c * 0)), (n0_rem * sizeof(int16_t))); + + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf0); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1xlt16: + { + memcpy( buf0, ( ( int16_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ), ( n0_rem * sizeof( int16_t ) ) ); + + selector1 = + _mm256_loadu_si256( (__m256i const *)buf0 ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1xlt16: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1xlt16: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1xlt16: + { + __m128i temp[2]; + __m256i temp_32[2], zero_reg; + __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; + + float float_buf[16]; + int8_t store_buf[16]; + + memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + + // Load the scale vector values into the register + scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); + scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + + zero_reg = _mm256_setzero_si256(); + + BLI_MM256_S16_DOWNSCALE2_EDGE_LT16(c_int16_0p0, 0) + } +POST_OPS_1xlt16_DISABLE: + ; + + // c[0,0-15] + _mm256_storeu_si256((__m256i_u *)buf0, c_int16_0p0); + + memcpy(c + (rs_c * 0) + (0 * 16), buf0, (n0_rem * sizeof(int16_t))); +} diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c new file mode 100644 index 0000000000..b24d49dac7 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c @@ -0,0 +1,915 @@ +/* + + BLIS + An object-based framework 
for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include + +#include "blis.h" +#include "lpgemm_kernels.h" +#include "lpgemm_s16_kern_macros.h" + +// 6x16 int8o16 kernel +LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) +{ + dim_t MR = 6; + dim_t NR = 16; + + static void *post_ops_labels[] = + { + &&POST_OPS_6x16_DISABLE, + &&POST_OPS_BIAS_6x16, + &&POST_OPS_RELU_6x16, + &&POST_OPS_RELU_SCALE_6x16, + &&POST_OPS_DOWNSCALE_6x16 + }; + + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + // The division is done by considering the vpmaddubsw instruction + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + // B matrix storage. + __m256i b0; + + // A matrix storage. + __m256i a_int32_0; + __m256i inter_vec; + + for (dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR) + { + // Registers to use for accumulating C. + __m256i c_int16_0p0 = _mm256_setzero_si256(); + + __m256i c_int16_1p0 = _mm256_setzero_si256(); + + __m256i c_int16_2p0 = _mm256_setzero_si256(); + + __m256i c_int16_3p0 = _mm256_setzero_si256(); + + __m256i c_int16_4p0 = _mm256_setzero_si256(); + + __m256i c_int16_5p0 = _mm256_setzero_si256(); + + for (dim_t kr = 0; kr < k_full_pieces; kr += 1) + { + int offset = kr * 2; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * kr) + (NR * 0))); + + // Broadcast a[0,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + // Broadcast a[1,kr:kr+2]. 
+ a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + + // Broadcast a[2,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 2) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); + + // Broadcast a[3,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 3) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0); + + // Broadcast a[4,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 4) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_4p0 = _mm256_add_epi16(inter_vec, c_int16_4p0); + + // Broadcast a[5,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 5) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[5,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_5p0 = _mm256_add_epi16(inter_vec, c_int16_5p0); + } + + // Handle k remainder. + if (k_partial_pieces > 0) + { + uint8_t a_kfringe; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0))); + + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + + a_kfringe = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); + + a_kfringe = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. 
+ // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0); + + a_kfringe = *(a + (rs_a * 4) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_4p0 = _mm256_add_epi16(inter_vec, c_int16_4p0); + + a_kfringe = *(a + (rs_a * 5) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_5p0 = _mm256_add_epi16(inter_vec, c_int16_5p0); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + + c_int16_1p0 = _mm256_mullo_epi16(selector1, c_int16_1p0); + + c_int16_2p0 = _mm256_mullo_epi16(selector1, c_int16_2p0); + + c_int16_3p0 = _mm256_mullo_epi16(selector1, c_int16_3p0); + + c_int16_4p0 = _mm256_mullo_epi16(selector1, c_int16_4p0); + + c_int16_5p0 = _mm256_mullo_epi16(selector1, c_int16_5p0); + + // Scale C by beta. + if (beta != 0) + { + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 0)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[1,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 1)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); + + // c[2,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 2)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_2p0 = _mm256_add_epi16(selector1, c_int16_2p0); + + // c[3,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 3)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_3p0 = _mm256_add_epi16(selector1, c_int16_3p0); + + // c[4,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 4)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_4p0 = _mm256_add_epi16(selector1, c_int16_4p0); + + // c[5,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 5)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_5p0 = _mm256_add_epi16(selector1, c_int16_5p0); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x16: + { + selector1 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 )) ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_add_epi16( selector1, c_int16_1p0 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_add_epi16( selector1, c_int16_2p0 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_add_epi16( selector1, c_int16_3p0 ); + + // c[4,0-15] + c_int16_4p0 = _mm256_add_epi16( selector1, c_int16_4p0 ); + + // c[5,0-15] + c_int16_5p0 = _mm256_add_epi16( selector1, c_int16_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x16: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + 
c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_max_epi16( selector1, c_int16_1p0 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_max_epi16( selector1, c_int16_2p0 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_max_epi16( selector1, c_int16_3p0 ); + + // c[4,0-15] + c_int16_4p0 = _mm256_max_epi16( selector1, c_int16_4p0 ); + + // c[5,0-15] + c_int16_5p0 = _mm256_max_epi16( selector1, c_int16_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x16: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + // c[1,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_1p0) + + // c[2,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_2p0) + + // c[3,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_3p0) + + // c[4,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_4p0) + + // c[5,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_5p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_6x16: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; + + /* Load the scale vector values into the register*/ + scale_1 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + scale_2 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + + BLI_MM256_S16_DOWNSCALE2(c_int16_0p0, c_int16_1p0, 0, 1); + + BLI_MM256_S16_DOWNSCALE2(c_int16_2p0, c_int16_3p0, 2, 3); + + BLI_MM256_S16_DOWNSCALE2(c_int16_4p0, c_int16_5p0, 4, 5); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 0 ) ) + ( 0 * 16 ) ), c_int16_0p0 ); + + // c[1,0-15] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 1 ) ) + ( 0 * 16 ) ), c_int16_1p0 ); + + // c[2,0-15] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 2 ) ) + ( 0 * 16 ) ), c_int16_2p0 ); + + // c[3,0-15] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 3 ) ) + ( 0 * 16 ) ), c_int16_3p0 ); + + // c[4,0-15] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 4 ) ) + ( 0 * 16 ) ), c_int16_4p0 ); + + // c[5,0-15] + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 5 ) ) + ( 0 * 16 ) ), c_int16_5p0 ); + + a = a + ( MR * ps_a ); + post_op_c_i += MR; + } + + if (m_partial_pieces > 0) + { + dim_t m_partial4 = m_partial_pieces / 4; + m_partial_pieces = m_partial_pieces % 4; + + dim_t m_partial2 = m_partial_pieces / 2; + dim_t m_partial = m_partial_pieces % 2; + + if (m_partial4 == 1) + { + lpgemm_rowvar_u8s8s16o16_4x16( + k0, + a, rs_a, cs_a, + b, rs_b, cs_b, + (c + (rs_c * m_full_pieces_loop_limit)), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale); + + // a pointer increment + a = a + (4 * ps_a); + m_full_pieces_loop_limit += 4; + post_op_c_i += 4; + } + + if (m_partial2 == 1) + { + lpgemm_rowvar_u8s8s16o16_2x16( + k0, + a, rs_a, cs_a, + b, rs_b, cs_b, + (c + (rs_c * m_full_pieces_loop_limit)), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale); + + // a pointer increment + a = a + (2 * ps_a); + m_full_pieces_loop_limit += 2; + post_op_c_i += 2; + } + + if (m_partial == 1) + { + lpgemm_rowvar_u8s8s16o16_1x16( + k0, + a, rs_a, cs_a, + b, rs_b, cs_b, + (c + (rs_c * m_full_pieces_loop_limit)), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, 
rs_c_downscale); + post_op_c_i += 1; + } + } +} + +// 6xlt16 int8o16 kernel +LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) +{ + dim_t MR = 6; + + static void *post_ops_labels[] = + { + &&POST_OPS_6xlt16_DISABLE, + &&POST_OPS_BIAS_6xlt16, + &&POST_OPS_RELU_6xlt16, + &&POST_OPS_RELU_SCALE_6xlt16, + &&POST_OPS_DOWNSCALE_6xlt16 + }; + + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + // The division is done by considering the vpmaddubsw instruction + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t buf0[16]; + int16_t buf1[16]; + int16_t buf2[16]; + int16_t buf3[16]; + int16_t buf4[16]; + int16_t buf5[16]; + + // B matrix storage. + __m256i b0; + + // A matrix storage. + __m256i a_int32_0; + __m256i inter_vec; + + for (dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR) + { + // Registers to use for accumulating C. + __m256i c_int16_0p0 = _mm256_setzero_si256(); + + __m256i c_int16_1p0 = _mm256_setzero_si256(); + + __m256i c_int16_2p0 = _mm256_setzero_si256(); + + __m256i c_int16_3p0 = _mm256_setzero_si256(); + + __m256i c_int16_4p0 = _mm256_setzero_si256(); + + __m256i c_int16_5p0 = _mm256_setzero_si256(); + + for (dim_t kr = 0; kr < k_full_pieces; kr += 1) + { + dim_t offset = kr * 2; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * kr) + (cs_b * 0))); + + // Broadcast a[0,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + // Broadcast a[1,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + + // Broadcast a[2,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 2) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); + + // Broadcast a[3,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 3) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0); + + // Broadcast a[4,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 4) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+4,0-31] + c_int16_4p0 = _mm256_add_epi16(inter_vec, c_int16_4p0); + + // Broadcast a[5,kr:kr+4]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 5) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. 
+ // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_5p0 = _mm256_add_epi16(inter_vec, c_int16_5p0); + } + + // Handle k remainder. + if (k_partial_pieces > 0) + { + uint8_t a_kfringe; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (cs_b * 0))); + + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + + a_kfringe = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); + + a_kfringe = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0); + + a_kfringe = *(a + (rs_a * 4) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_4p0 = _mm256_add_epi16(inter_vec, c_int16_4p0); + + a_kfringe = *(a + (rs_a * 5) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_5p0 = _mm256_add_epi16(inter_vec, c_int16_5p0); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + + c_int16_1p0 = _mm256_mullo_epi16(selector1, c_int16_1p0); + + c_int16_2p0 = _mm256_mullo_epi16(selector1, c_int16_2p0); + + c_int16_3p0 = _mm256_mullo_epi16(selector1, c_int16_3p0); + + c_int16_4p0 = _mm256_mullo_epi16(selector1, c_int16_4p0); + + c_int16_5p0 = _mm256_mullo_epi16(selector1, c_int16_5p0); + + // Scale C by beta. 
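+		// buf0-buf5 are reused for every MR=6 row block of this ir loop; as
+		// in the other *lt16 kernels, only n0_rem int16 elements per row are
+		// copied between C and the buffers, keeping every access to C within
+		// its valid columns.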
+ if (beta != 0) + { + memcpy(buf0, (c + (rs_c * (ir + 0))), (n0_rem * sizeof(int16_t))); + memcpy(buf1, (c + (rs_c * (ir + 1))), (n0_rem * sizeof(int16_t))); + memcpy(buf2, (c + (rs_c * (ir + 2))), (n0_rem * sizeof(int16_t))); + memcpy(buf3, (c + (rs_c * (ir + 3))), (n0_rem * sizeof(int16_t))); + memcpy(buf4, (c + (rs_c * (ir + 4))), (n0_rem * sizeof(int16_t))); + memcpy(buf5, (c + (rs_c * (ir + 5))), (n0_rem * sizeof(int16_t))); + + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf0); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[1,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf1); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); + + // c[2,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf2); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_2p0 = _mm256_add_epi16(selector1, c_int16_2p0); + + // c[3,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf3); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_3p0 = _mm256_add_epi16(selector1, c_int16_3p0); + + // c[4,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf4); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_4p0 = _mm256_add_epi16(selector1, c_int16_4p0); + + // c[5,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf5); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_5p0 = _mm256_add_epi16(selector1, c_int16_5p0); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6xlt16: + { + memcpy( buf0, ( ( int16_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ), + ( n0_rem * sizeof( int16_t ) ) ); + + selector1 = + _mm256_loadu_si256( ( __m256i const* )buf0 ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_add_epi16( selector1, c_int16_1p0 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_add_epi16( selector1, c_int16_2p0 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_add_epi16( selector1, c_int16_3p0 ); + + // c[4,0-15] + c_int16_4p0 = _mm256_add_epi16( selector1, c_int16_4p0 ); + + // c[5,0-15] + c_int16_5p0 = _mm256_add_epi16( selector1, c_int16_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6xlt16: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_max_epi16( selector1, c_int16_1p0 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_max_epi16( selector1, c_int16_2p0 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_max_epi16( selector1, c_int16_3p0 ); + + // c[4,0-15] + c_int16_4p0 = _mm256_max_epi16( selector1, c_int16_4p0 ); + + // c[5,0-15] + c_int16_5p0 = _mm256_max_epi16( selector1, c_int16_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6xlt16: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + // c[1,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_1p0) + + // c[2,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_2p0) + + // c[3,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_3p0) + + // c[4,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_4p0) + + // c[5,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_5p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_DOWNSCALE_6xlt16: + { + __m128i temp[2]; + __m256i temp_32[2]; 
+ __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; + + float float_buf[16]; + int8_t store_buf[16]; + + memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + + // Load the scale vector values into the register + scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); + scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + + BLI_MM256_S16_DOWNSCALE2_LT16(c_int16_0p0, c_int16_1p0, 0, 1) + + BLI_MM256_S16_DOWNSCALE2_LT16(c_int16_2p0, c_int16_3p0, 2, 3) + + BLI_MM256_S16_DOWNSCALE2_LT16(c_int16_4p0, c_int16_5p0, 4, 5) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6xlt16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm256_storeu_si256((__m256i *)buf0, c_int16_0p0); + + // c[1,0-15] + _mm256_storeu_si256((__m256i *)buf1, c_int16_1p0); + + // c[2,0-15] + _mm256_storeu_si256((__m256i *)buf2, c_int16_2p0); + + // c[3,0-15] + _mm256_storeu_si256((__m256i *)buf3, c_int16_3p0); + + // c[4,0-15] + _mm256_storeu_si256((__m256i *)buf4, c_int16_4p0); + + // c[5,0-15] + _mm256_storeu_si256((__m256i *)buf5, c_int16_5p0); + + memcpy(c + (rs_c * (ir + 0)) + (0 * 16), buf0, (n0_rem * sizeof(int16_t))); + + // c[1,0-15] + memcpy(c + (rs_c * (ir + 1)) + (0 * 16), buf1, (n0_rem * sizeof(int16_t))); + + // c[2,0-15] + memcpy(c + (rs_c * (ir + 2)) + (0 * 16), buf2, (n0_rem * sizeof(int16_t))); + + // c[3,0-15] + memcpy(c + (rs_c * (ir + 3)) + (0 * 16), buf3, (n0_rem * sizeof(int16_t))); + + // c[4,0-15] + memcpy(c + (rs_c * (ir + 4)) + (0 * 16), buf4, (n0_rem * sizeof(int16_t))); + + // c[5,0-15] + memcpy(c + (rs_c * (ir + 5)) + (0 * 16), buf5, (n0_rem * sizeof(int16_t))); + + a = a + (MR * ps_a); + post_op_c_i += MR; + } + + if (m_partial_pieces > 0) + { + dim_t m_partial4 = m_partial_pieces / 4; + m_partial_pieces = m_partial_pieces % 4; + + dim_t m_partial2 = m_partial_pieces / 2; + dim_t m_partial = m_partial_pieces % 2; + + if (m_partial4 == 1) + { + lpgemm_rowvar_u8s8s16o16_4xlt16( + k0, + a, rs_a, cs_a, + b, rs_b, cs_b, + (c + (rs_c * m_full_pieces_loop_limit)), rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale); + + // a pointer increment + a = a + (4 * ps_a); + m_full_pieces_loop_limit += 4; + post_op_c_i += 4; + } + + if (m_partial2 == 1) + { + lpgemm_rowvar_u8s8s16o16_2xlt16( + k0, + a, rs_a, cs_a, + b, rs_b, cs_b, + (c + (rs_c * m_full_pieces_loop_limit)), rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale); + + // a pointer increment + a = a + (2 * ps_a); + m_full_pieces_loop_limit += 2; + post_op_c_i += 2; + } + + if (m_partial == 1) + { + lpgemm_rowvar_u8s8s16o16_1xlt16( + k0, + a, rs_a, cs_a, + b, rs_b, cs_b, + (c + (rs_c * m_full_pieces_loop_limit)), rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale); + post_op_c_i += 1; + } + } +} diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_amd256.c new file mode 100644 index 0000000000..ac9cb469e3 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_amd256.c @@ -0,0 +1,268 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include + +#include "blis.h" +#include "lpgemm_packb_s16.h" +#include "lpgemm_config.h" + +void get_packb_nr32_u8s8s16o16_strides + ( + dim_t *rs_b, + dim_t *cs_b + ) +{ + *rs_b = lpgemm_get_block_size_NR_global_cntx( U8S8S16OS16 ) * 2; + *cs_b = lpgemm_get_block_size_NR_global_cntx( U8S8S16OS16 ); +} + +void packb_nrlt16_u8s8s16o16 + ( + int8_t *pack_b_buffer_u8s8s16o16, + const int8_t *b, + const dim_t ldb, + const dim_t rows, + dim_t n0_partial_rem + ) +{ + dim_t k_full_pieces_blks = rows / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = rows % 2; + dim_t NR = 16; + dim_t kr_new = 0; + + int8_t buf0[16], buf1[16]; + + __m128i b_vec[2], inter_vec[2]; + + for (dim_t kr = 0; kr < k_full_pieces; kr += 2) + { + memcpy(buf0, (b + (ldb * (kr + 0))), (n0_partial_rem * sizeof(int8_t))); + memcpy(buf1, (b + (ldb * (kr + 1))), (n0_partial_rem * sizeof(int8_t))); + + // Read b[0,0], b[0,1], b[0,2]......., b[0,15] + b_vec[0] = _mm_loadu_si128((__m128i *)buf0); + // Read b[1,0], b[1,1], b[1,2]......., b[1,15] + b_vec[1] = _mm_loadu_si128((__m128i *)buf1); + + // Reorder B matrix inputs to suit vpmaddubsw instructions + inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]); + inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]); + + // Store b[0,0], b[1,0], b[0,1]......., b[0,7], b[1,7] + _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + (kr_new * NR)), inter_vec[0]); + // Store b[0,8], b[1,8], b[0,9]......., b[0,15], b[1,15] + _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]); + + // Increment to ignore the padded bits + kr_new += 2; + } + + // Handle k partial cases + if (k_partial_pieces > 0) + { + memcpy(buf0, (b + (ldb * (k_full_pieces + 0))), (n0_partial_rem * sizeof(int8_t))); + + // Read b[0,0], b[0,1], b[0,2]......., b[0,15] + b_vec[0] = _mm_loadu_si128((__m128i *)buf0); + b_vec[1] = _mm_setzero_si128(); // Initialize with zero for padding + + // Reorder B matrix inputs to suit vpmaddubsw instructions + inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]); + 
inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]); + + // Store b[0,0], 0, b[0,1]......., b[0,7], 0 + _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 0) * NR)), inter_vec[0]); + + // Store b[0,8], 0, b[0,9]......., b[0,15], 0 + _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]); + } +} + +void packb_nr16_u8s8s16o16( + int8_t *pack_b_buffer_u8s8s16o16, + const int8_t *b, + const dim_t ldb, + const dim_t rows) +{ + dim_t k_full_pieces_blks = rows / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = rows % 2; + dim_t NR = 16; + dim_t kr_new = 0; + + __m128i b_vec[2], inter_vec[2]; + + for (dim_t kr = 0; kr < k_full_pieces; kr += 2) + { + // Read b[0,0], b[0,1], b[0,2]......., b[0,15] + b_vec[0] = _mm_loadu_si128((__m128i const *)(b + (ldb * (kr + 0)))); + + // Read b[1,0], b[1,1], b[1,2]......., b[1,15] + b_vec[1] = _mm_loadu_si128((__m128i const *)(b + (ldb * (kr + 1)))); + + // Reorder B matrix inputs to suit vpmaddubsw instructions + inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]); + inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]); + + // Store b[0,0], b[1,0], b[0,1]......., b[0,7], b[1,7] + _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 0) * NR)), inter_vec[0]); + + // Store b[0,8], b[1,8], b[0,9]......., b[0,15], b[1,15] + _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]); + + // Increment to ignore the padded bits + kr_new += 2; + } + + if (k_partial_pieces > 0) + { + // Read b[0,0], b[0,1], b[0,2]......., b[0,15] + b_vec[0] = _mm_loadu_si128((__m128i const *)(b + (ldb * (k_full_pieces + 0)))); + b_vec[1] = _mm_setzero_si128(); // Initialize with zero for padding + + // Reorder B matrix inputs to suit vpmaddubsw instructions + inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]); + inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]); + + // Store b[0,0], 0, b[0,1]......., b[0,7], 0 + _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 0) * NR)), inter_vec[0]); + // Store b[0,8], 0, b[0,9]......., b[0,15], 0 + _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]); + } +} + +void packb_nr32_u8s8s16o16( + int8_t *pack_b_buffer_u8s8s16o16, + const int8_t *b, + const dim_t ldb, + const dim_t cols, + const dim_t rows, + dim_t *rs_b, + dim_t *cs_b) +{ + dim_t NR = 32; + + dim_t n_full_pieces = cols / NR; + dim_t n_full_pieces_loop_limit = n_full_pieces * NR; + dim_t n_partial_pieces = cols % NR; + dim_t k_full_pieces_blks = rows / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = rows % 2; + + dim_t KC_updated = rows; + + // Making multiple of 2 to suit k in vpmaddubsw + KC_updated += (KC_updated & 0x1); + + __m256i b_vec[2], inter_vec[2]; + + for (dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR) + { + for (dim_t kr = 0; kr < k_full_pieces; kr += 2) + { + // Read b[0,0], b[0,1], b[0,2]......., b[0,31] + b_vec[0] = _mm256_loadu_si256((__m256i const *)(b + (ldb * (kr + 0)) + jc)); + + // Read b[1,0], b[1,1], b[1,2]......., b[1,31] + b_vec[1] = _mm256_loadu_si256((__m256i const *)(b + (ldb * (kr + 1)) + jc)); + + // Reorder B matrix inputs to suit vpmaddubsw instructions + inter_vec[0] = _mm256_unpacklo_epi8(b_vec[0], b_vec[1]); + inter_vec[1] = _mm256_unpackhi_epi8(b_vec[0], b_vec[1]); + + b_vec[0] = _mm256_permute2f128_si256(inter_vec[0], inter_vec[1], 0x20); + b_vec[1] = _mm256_permute2f128_si256(inter_vec[0], inter_vec[1], 0x31); + + // Store 
B[0,0], B[1,0], B[0,1], B[1,1], ......, B[0,15], B[1,15] + _mm256_storeu_si256((__m256i *)(pack_b_buffer_u8s8s16o16 + ((jc * KC_updated) + (kr * NR))), b_vec[0]); + // Store B[0,16], B[1,16], B[0,17], B[1,17], ......, B[0,31], B[1,31] + _mm256_storeu_si256((__m256i *)(pack_b_buffer_u8s8s16o16 + ((jc * KC_updated) + ((kr + 1) * NR))), b_vec[1]); + } + + if (k_partial_pieces > 0) + { + // Read b[0,0], b[0,1], b[0,2]......., b[0,31] + b_vec[0] = _mm256_loadu_si256((__m256i const *)(b + (ldb * (k_full_pieces + 0)) + jc)); + b_vec[1] = _mm256_setzero_si256(); // Initialize with zero for padding + + // Reorder B matrix inputs to suit vpmaddubsw instructions + inter_vec[0] = _mm256_unpacklo_epi8(b_vec[0], b_vec[1]); + inter_vec[1] = _mm256_unpackhi_epi8(b_vec[0], b_vec[1]); + + b_vec[0] = _mm256_permute2f128_si256(inter_vec[0], inter_vec[1], 0x20); + b_vec[1] = _mm256_permute2f128_si256(inter_vec[0], inter_vec[1], 0x31); + + // Store B[0,0], B[1,0], B[0,1], B[1,1], ......, B[0,15], B[1,15] + _mm256_storeu_si256((__m256i *)(pack_b_buffer_u8s8s16o16 + ((jc * KC_updated) + (k_full_pieces * NR))), b_vec[0]); + // Store B[0,16], B[1,16], B[0,17], B[1,17], ......, B[0,31], B[1,31] + _mm256_storeu_si256((__m256i *)(pack_b_buffer_u8s8s16o16 + ((jc * KC_updated) + ((k_full_pieces + 1) * NR))), b_vec[1]); + } + } + + // B matrix packing when n < NR + if (n_partial_pieces > 0) + { + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization after packing. Any n0 < NR(32) can be expressed + // as n0 = 16 + n`. + dim_t n0_16 = n_partial_pieces / 16; + dim_t n0_partial_rem = n_partial_pieces % 16; + + dim_t n0_partial_pack = 0; + + if (n0_16 == 1) + { + packb_nr16_u8s8s16o16( + (pack_b_buffer_u8s8s16o16 + + (n_full_pieces_loop_limit * KC_updated)), + (b + n_full_pieces_loop_limit), ldb, rows); + + n0_partial_pack = 16; + } + + if (n0_partial_rem > 0) + { + packb_nrlt16_u8s8s16o16( + (pack_b_buffer_u8s8s16o16 + (n_full_pieces_loop_limit * KC_updated) + + (n0_partial_pack * KC_updated)), + (b + n_full_pieces_loop_limit + n0_partial_pack), + ldb, rows, n0_partial_rem); + } + } + + *rs_b = NR * 2; + *cs_b = NR; +} \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_s16.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_s16.h new file mode 100644 index 0000000000..b8d73c862c --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_s16.h @@ -0,0 +1,55 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_GEMM_INT16_PACKB +#define BLIS_GEMM_INT16_PACKB + +void get_packb_nr32_u8s8s16o16_strides + ( + dim_t* rs_b, + dim_t* cs_b + ); + +void packb_nr32_u8s8s16o16 + ( + int8_t *pack_b_buffer_u8s8s16o16, + const int8_t *b, + const dim_t ldb, + const dim_t cols, + const dim_t rows, + dim_t *rs_b, + dim_t *cs_b + ); + +#endif // BLIS_GEMM_INT16_PACKB \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h new file mode 100644 index 0000000000..00583977f3 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h @@ -0,0 +1,404 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_S16_KERN_MACROS_H +#define LPGEMM_S16_KERN_MACROS_H +#define S8_MIN (-128) +#define S8_MAX (+127) + +#define RELU_SCALE_OP_S16_AVX2(reg) \ + selector1 = _mm256_setzero_si256();\ + selector1 = _mm256_cmpgt_epi16 ( selector1, reg ); \ + \ + /* Only < 0 elements in b0. */ \ + b0 = _mm256_and_si256 ( selector1, reg ); \ +\ + /* Only >= 0 elements in c_int16_0p0. */ \ + reg = _mm256_andnot_si256( selector1, reg ); \ + \ + /* Only scaling for < 0 elements. */ \ + b0 = _mm256_mullo_epi16( b0, selector2 ); \ + \ + /* Combine the scaled < 0 and >= 0 elements. 
*/ \ + reg = _mm256_or_si256( b0, reg ); \ + \ + +//-------------------------------------------------------------------------- + +#define BLI_MM256_S16_DOWNSCALE(c_int16__p0, c_int16__p1, vec_loc)\ +\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(c_int16__p0, 0);\ + /* Extract the second 128 bits of the register*/\ + temp[1] = _mm256_extractf128_si256(c_int16__p0, 1);\ +\ + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ +\ + /* Multiply the C matrix by the scale value*/\ + res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ + res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ +\ + /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ + res_1 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ + res_2 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps (res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ +\ + /* Convert the clipped float32 scaled rounded value to int32 */\ + temp_32[0] = _mm256_cvtps_epi32(res_1);\ + temp_32[1] = _mm256_cvtps_epi32(res_2);\ +\ + /* Convert the s32 to s16 */\ + c_int16__p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ +\ + /*Permute to make sure the order is correct*/\ + c_int16__p0 = _mm256_permute4x64_epi64(c_int16__p0, 0XD8);\ +\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(c_int16__p1, 0);\ +\ + /* Extract the second 128 bits of the register*/\ + temp[1] = _mm256_extractf128_si256(c_int16__p1, 1);\ +\ + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ +\ + /* Multiply the C matrix by the scale value*/\ + res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ + res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ +\ + /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ + res_1 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps (res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ + res_2 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps (res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ +\ + /* Convert the clipped float32 scaled rounded value to int32 */\ + temp_32[0] = _mm256_cvtps_epi32(res_1);\ + temp_32[1] = _mm256_cvtps_epi32(res_2);\ +\ + /* Convert the s32 to s16 */\ + c_int16__p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ +\ + /*Permute to make sure the order is correct*/\ + c_int16__p1 = _mm256_permute4x64_epi64(c_int16__p1, 0XD8);\ +\ + /* Convert the s16 to s8 */\ + store_reg = _mm256_packs_epi16(c_int16__p0, c_int16__p1);\ + store_reg = _mm256_permute4x64_epi64(store_reg, 0XD8);\ +\ + /* Store the result in s8 form */\ + _mm256_storeu_si256((__m256i *)(( int8_t* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + vec_loc ) ) + post_op_c_j), store_reg);\ +\ + +//-------------------------------------------------------------------------- + +#define BLI_MM256_S16_DOWNSCALE2(c_int16__p0, c_int16__p1, vec_loc1, vec_loc2)\ 
+\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(c_int16__p0, 0);\ + /* Extract the second 128 bits of the register*/\ + temp[1] = _mm256_extractf128_si256(c_int16__p0, 1);\ +\ + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ +\ + /* Multiply the C matrix by the scale value*/\ + res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ + res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ +\ + /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ + res_1 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ + res_2 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ +\ + /* Convert the clipped float32 scaled rounded value to int32 */\ + temp_32[0] = _mm256_cvtps_epi32(res_1);\ + temp_32[1] = _mm256_cvtps_epi32(res_2);\ +\ + /* Convert the s32 to s16 */\ + c_int16__p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ +\ + /*Permute to make sure the order is correct*/\ + c_int16__p0 = _mm256_permute4x64_epi64(c_int16__p0, 0XD8);\ +\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(c_int16__p1, 0);\ +\ + /* Extract the second 128 bits of the register*/\ + temp[1] = _mm256_extractf128_si256(c_int16__p1, 1);\ +\ + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ +\ + /* Multiply the C matrix by the scale value*/\ + res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ + res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ +\ + /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ + res_1 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ + res_2 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ +\ + /* Convert the clipped float32 scaled rounded value to int32 */\ + temp_32[0] = _mm256_cvtps_epi32(res_1);\ + temp_32[1] = _mm256_cvtps_epi32(res_2);\ +\ + /* Convert the s32 to s16 */\ + c_int16__p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ +\ + /*Permute to make sure the order is correct*/\ + c_int16__p1 = _mm256_permute4x64_epi64(c_int16__p1, 0XD8);\ +\ + /* Convert the s16 to s8 */\ + store_reg = _mm256_packs_epi16(c_int16__p0, c_int16__p1);\ + store_reg = _mm256_permute4x64_epi64(store_reg, 0XD8);\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(store_reg, 0);\ + /* Extract the second 128 bits of the register*/\ + temp[1] = _mm256_extractf128_si256(store_reg, 1);\ +\ + /* Store the result in s8 form */\ + _mm_storeu_si128((__m128i *)(( int8_t* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + vec_loc1 ) ) + post_op_c_j), temp[0]);\ + _mm_storeu_si128((__m128i *)(( int8_t* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + vec_loc2 ) ) + post_op_c_j), temp[1]);\ +\ + 
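+/* Worked example of the scale/round/clip step used by these downscale
+   macros: an s16 accumulator value of 1000 with a scale factor of 0.05f
+   becomes 50.0f, which rounds to 50 and stays in range, while 20000 with
+   the same scale becomes 1000.0f and saturates to S8_MAX (127). The
+   *_LT16 variant below differs only in the store: the packed s8 result
+   goes through a local store_buf and a memcpy of n0_rem bytes, so partial
+   output rows are never written past their end. */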
+//-------------------------------------------------------------------------- + +#define BLI_MM256_S16_DOWNSCALE2_LT16(c_int16__p0, c_int16__p1, vec_loc1, vec_loc2)\ +\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(c_int16__p0, 0);\ + /* Extract the second 128 bits of the register*/\ + temp[1] = _mm256_extractf128_si256(c_int16__p0, 1);\ +\ + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ +\ + /* Multiply the C matrix by the scale value*/\ + res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ + res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ +\ + /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ + res_1 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ + res_2 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ +\ + /* Convert the clipped float32 scaled rounded value to int32 */\ + temp_32[0] = _mm256_cvtps_epi32(res_1);\ + temp_32[1] = _mm256_cvtps_epi32(res_2);\ +\ + /* Convert the s32 to s16 */\ + c_int16__p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ +\ + /*Permute to make sure the order is correct*/\ + c_int16__p0 = _mm256_permute4x64_epi64(c_int16__p0, 0XD8);\ +\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(c_int16__p1, 0);\ +\ + /* Extract the second 128 bits of the register*/\ + temp[1] = _mm256_extractf128_si256(c_int16__p1, 1);\ +\ + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ +\ + /* Multiply the C matrix by the scale value*/\ + res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ + res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ +\ + /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ + res_1 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ + res_2 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ +\ + /* Convert the clipped float32 scaled rounded value to int32 */\ + temp_32[0] = _mm256_cvtps_epi32(res_1);\ + temp_32[1] = _mm256_cvtps_epi32(res_2);\ +\ + /* Convert the s32 to s16 */\ + c_int16__p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ +\ + /*Permute to make sure the order is correct*/\ + c_int16__p1 = _mm256_permute4x64_epi64(c_int16__p1, 0XD8);\ +\ + /* Convert the s16 to s8 */\ + store_reg = _mm256_packs_epi16(c_int16__p0, c_int16__p1);\ + store_reg = _mm256_permute4x64_epi64(store_reg, 0XD8);\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(store_reg, 0);\ + /* Extract the second 128 bits of the register*/\ + temp[1] = _mm256_extractf128_si256(store_reg, 1);\ +\ + /* Store the result in s8 form */\ + _mm_storeu_si128((__m128i *)store_buf, temp[0]);\ + memcpy( ( int8_t* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( 
post_op_c_i + vec_loc1 ) ) + post_op_c_j \ + , store_buf, ( n0_rem * sizeof( int8_t ) ) ); \ +\ + _mm_storeu_si128((__m128i *)store_buf, temp[1]);\ + memcpy( ( int8_t* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + vec_loc2 ) ) + post_op_c_j \ + , store_buf, ( n0_rem * sizeof( int8_t ) ) ); \ +\ + +//-------------------------------------------------------------------------- + +#define BLI_MM256_S16_DOWNSCALE2_EDGE(c_int16__p0, vec_ind)\ +\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(c_int16__p0, 0);\ + /* Extract the second 128 bits of the register*/\ + temp[1] = _mm256_extractf128_si256(c_int16__p0, 1);\ +\ + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ +\ + /* Multiply the C matrix by the scale value*/\ + res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ + res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ +\ + /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ + res_1 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ + res_2 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ +\ + /* Convert the clipped float32 scaled rounded value to int32 */\ + temp_32[0] = _mm256_cvtps_epi32(res_1);\ + temp_32[1] = _mm256_cvtps_epi32(res_2);\ +\ + /* Convert the s32 to s16 */\ + c_int16__p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ +\ + /*Permute to make sure the order is correct*/\ + c_int16__p0 = _mm256_permute4x64_epi64(c_int16__p0, 0XD8);\ +\ + /* Convert the s16 to s8 */\ + store_reg = _mm256_packs_epi16(c_int16__p0, zero_reg);\ + store_reg = _mm256_permute4x64_epi64(store_reg, 0XD8);\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(store_reg, 0);\ +\ + /* Store the result in s8 form */\ + _mm_storeu_si128((__m128i *)(( int8_t* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + vec_ind ) ) + post_op_c_j), temp[0]);\ +\ + +//-------------------------------------------------------------------------- + +#define BLI_MM256_S16_DOWNSCALE2_EDGE_LT16(c_int16__p0, vec_ind)\ +\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(c_int16__p0, 0);\ + /* Extract the second 128 bits of the register*/\ + temp[1] = _mm256_extractf128_si256(c_int16__p0, 1);\ +\ + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ +\ + /* Multiply the C matrix by the scale value*/\ + res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ + res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ +\ + /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ + res_1 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ + res_2 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ +\ + /* 
Convert the clipped float32 scaled rounded value to int32 */\ + temp_32[0] = _mm256_cvtps_epi32(res_1);\ + temp_32[1] = _mm256_cvtps_epi32(res_2);\ +\ + /* Convert the s32 to s16 */\ + c_int16__p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ +\ + /*Permute to make sure the order is correct*/\ + c_int16__p0 = _mm256_permute4x64_epi64(c_int16__p0, 0XD8);\ +\ + /* Convert the s16 to s8 */\ + store_reg = _mm256_packs_epi16(c_int16__p0, zero_reg);\ + store_reg = _mm256_permute4x64_epi64(store_reg, 0XD8);\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(store_reg, 0);\ +\ + /* Store the result in s8 form */\ + _mm_storeu_si128((__m128i *)store_buf, temp[0]);\ + memcpy( (( int8_t* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + vec_ind ) ) + post_op_c_j) \ + ,store_buf, ( n0_rem * sizeof( int8_t ) ) ); \ +\ + +#endif //LPGEMM_S16_KERN_MACROS_H \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c new file mode 100644 index 0000000000..f249106a3c --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c @@ -0,0 +1,1061 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include + +#include "blis.h" +#include "lpgemm_kernels.h" +#include "lpgemm_s32_kern_macros.h" + +#ifdef BLIS_KERNELS_ZEN4 +// 6x64 int8o32 kernel +LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6x64_DISABLE, + &&POST_OPS_BIAS_6x64, + &&POST_OPS_RELU_6x64, + &&POST_OPS_RELU_SCALE_6x64, + &&POST_OPS_DOWNSCALE_6x64 + }; + + dim_t MR = 6; + dim_t NR = 64; + + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + if ( n0 < NR ) + { + dim_t n0_rem = n0 % 16; + + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization. Any n0 < NR(64) can be expressed as n0 = 48 + n` + // or n0 = 32 + n` or n0 = 16 + n`, where n` < 16. + dim_t n0_48 = n0 / 48; + dim_t n0_32 = n0 / 32; + dim_t n0_16 = n0 / 16; + + // KC when not multiple of 4 will have padding to make it multiple of + // 4 in packed buffer. Also the k0 cannot be passed as the updated + // value since A matrix is not packed and requires original k0. + dim_t k0_updated = k0; + if ( k_partial_pieces > 0 ) + { + k0_updated += ( 4 - k_partial_pieces ); + } + + if ( n0_48 == 1 ) + { + lpgemm_rowvar_u8s8s32o32_6x48 + ( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ( ( rs_b / 4 ) * 3 ), cs_b, + c, rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + + b = b + ( 48 * k0_updated ); // k0x48 packed contiguosly. + c = c + 48; + post_op_c_j += 48; + } + else if ( n0_32 == 1 ) + { + lpgemm_rowvar_u8s8s32o32_6x32 + ( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ( ( rs_b / 4 ) * 2 ), cs_b, + c, rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + + b = b + ( 32 * k0_updated ); // k0x32 packed contiguosly. + c = c + 32; + post_op_c_j += 32; + } + else if ( n0_16 == 1 ) + { + lpgemm_rowvar_u8s8s32o32_6x16 + ( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ( ( rs_b / 4 ) * 1 ), cs_b, + c, rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + + b = b + ( 16 * k0_updated ); // k0x16 packed contiguosly. + c = c + 16; + post_op_c_j += 16; + } + + if ( n0_rem > 0 ) + { + lpgemm_rowvar_u8s8s32o32_6xlt16 + ( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ( ( rs_b / 4 ) * 1 ), cs_b, + c, rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + + // No leftover fringe after this point. + } + + return; + } + + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + __m512i b3; + + // A matrix storage. + __m512i a_int32_0; + __m512i a_int32_1; + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. 
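+		// Each of the 6 rows of the MRxNR (6x64) tile uses 4 ZMM
+		// accumulators (4 x 16 s32 elements = 64 columns), giving the 24
+		// registers c_int32_<row>p<0..3> below.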
+ __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + __m512i c_int32_0p3 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + __m512i c_int32_1p2 = _mm512_setzero_epi32(); + __m512i c_int32_1p3 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + __m512i c_int32_2p2 = _mm512_setzero_epi32(); + __m512i c_int32_2p3 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + __m512i c_int32_3p1 = _mm512_setzero_epi32(); + __m512i c_int32_3p2 = _mm512_setzero_epi32(); + __m512i c_int32_3p3 = _mm512_setzero_epi32(); + + __m512i c_int32_4p0 = _mm512_setzero_epi32(); + __m512i c_int32_4p1 = _mm512_setzero_epi32(); + __m512i c_int32_4p2 = _mm512_setzero_epi32(); + __m512i c_int32_4p3 = _mm512_setzero_epi32(); + + __m512i c_int32_5p0 = _mm512_setzero_epi32(); + __m512i c_int32_5p1 = _mm512_setzero_epi32(); + __m512i c_int32_5p2 = _mm512_setzero_epi32(); + __m512i c_int32_5p3 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // The instructions are arranged in a mixed way to reduce data + // chain dependencies. + + // Load 4 rows with 64 elements each from B to 4 ZMM registers. It + // is to be noted that the B matrix is packed for use in vnni + // instructions and each load to ZMM register will have 4 elements + // along k direction and 16 elements across n directions, so 4x16 + // elements to a ZMM register. + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); + c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. 
+ // c[3,0-63] = a[3,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_1, b0 ); + + // Broadcast a[4,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_1, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_1, b2 ); + c_int32_3p3 = _mm512_dpbusd_epi32( c_int32_3p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-63] = a[4,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + + // Broadcast a[5,kr:kr+4]. + a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); + c_int32_4p3 = _mm512_dpbusd_epi32( c_int32_4p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[5,0-63] = a[5,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_1, b0 ); + c_int32_5p1 = _mm512_dpbusd_epi32( c_int32_5p1, a_int32_1, b1 ); + c_int32_5p2 = _mm512_dpbusd_epi32( c_int32_5p2, a_int32_1, b2 ); + c_int32_5p3 = _mm512_dpbusd_epi32( c_int32_5p3, a_int32_1, b3 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); + c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. 
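+			// Only k_partial_pieces (< 4) bytes remain in this row of A, so
+			// the 4-byte group is staged through a_kfringe_buf with memcpy
+			// instead of being read directly, avoiding a load past the end
+			// of the row.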
+ memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-63] = a[3,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_1, b0 ); + + // Broadcast a[4,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_1, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_1, b2 ); + c_int32_3p3 = _mm512_dpbusd_epi32( c_int32_3p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-63] = a[4,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + + // Broadcast a[5,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); + c_int32_4p3 = _mm512_dpbusd_epi32( c_int32_4p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[5,0-63] = a[5,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_1, b0 ); + c_int32_5p1 = _mm512_dpbusd_epi32( c_int32_5p1, a_int32_1, b1 ); + c_int32_5p2 = _mm512_dpbusd_epi32( c_int32_5p2, a_int32_1, b2 ); + c_int32_5p3 = _mm512_dpbusd_epi32( c_int32_5p3, a_int32_1, b3 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + c_int32_0p3 = _mm512_mullo_epi32( selector1, c_int32_0p3 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); + c_int32_1p3 = _mm512_mullo_epi32( selector1, c_int32_1p3 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); + c_int32_2p3 = _mm512_mullo_epi32( selector1, c_int32_2p3 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); + c_int32_3p2 = _mm512_mullo_epi32( selector1, c_int32_3p2 ); + c_int32_3p3 = _mm512_mullo_epi32( selector1, c_int32_3p3 ); + + c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); + c_int32_4p1 = _mm512_mullo_epi32( selector1, c_int32_4p1 ); + c_int32_4p2 = _mm512_mullo_epi32( selector1, c_int32_4p2 ); + c_int32_4p3 = _mm512_mullo_epi32( selector1, c_int32_4p3 ); + + c_int32_5p0 = _mm512_mullo_epi32( selector1, c_int32_5p0 ); + c_int32_5p1 = _mm512_mullo_epi32( selector1, c_int32_5p1 ); + c_int32_5p2 = _mm512_mullo_epi32( selector1, c_int32_5p2 ); + c_int32_5p3 = _mm512_mullo_epi32( selector1, c_int32_5p3 
); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p3 = _mm512_add_epi32( selector1, c_int32_0p3 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); + + // c[1,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p3 = _mm512_add_epi32( selector1, c_int32_1p3 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); + + // c[2,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p3 = _mm512_add_epi32( selector1, c_int32_2p3 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p2 = _mm512_add_epi32( selector1, c_int32_3p2 ); + + // c[3,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p3 = _mm512_add_epi32( selector1, c_int32_3p3 ); + + // c[4,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); 
+ c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p1 = _mm512_add_epi32( selector1, c_int32_4p1 ); + + // c[4,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p2 = _mm512_add_epi32( selector1, c_int32_4p2 ); + + // c[4,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p3 = _mm512_add_epi32( selector1, c_int32_4p3 ); + + // c[5,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); + + // c[5,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p1 = _mm512_add_epi32( selector1, c_int32_5p1 ); + + // c[5,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p2 = _mm512_add_epi32( selector1, c_int32_5p2 ); + + // c[5,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p3 = _mm512_add_epi32( selector1, c_int32_5p3 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x64: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_add_epi32( a_int32_1, c_int32_0p3 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); + + // c[1,48-63] + c_int32_1p3 = _mm512_add_epi32( a_int32_1, c_int32_1p3 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); + + // c[2,48-63] + c_int32_2p3 = _mm512_add_epi32( a_int32_1, c_int32_2p3 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3, 16-31] + c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_add_epi32( a_int32_0, c_int32_3p2 ); + + // c[3,48-63] + c_int32_3p3 = _mm512_add_epi32( a_int32_1, c_int32_3p3 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4, 16-31] + c_int32_4p1 = _mm512_add_epi32( selector2, c_int32_4p1 ); + + // c[4,32-47] + c_int32_4p2 = _mm512_add_epi32( 
a_int32_0, c_int32_4p2 ); + + // c[4,48-63] + c_int32_4p3 = _mm512_add_epi32( a_int32_1, c_int32_4p3 ); + + // c[5,0-15] + c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); + + // c[5, 16-31] + c_int32_5p1 = _mm512_add_epi32( selector2, c_int32_5p1 ); + + // c[5,32-47] + c_int32_5p2 = _mm512_add_epi32( a_int32_0, c_int32_5p2 ); + + // c[5,48-63] + c_int32_5p3 = _mm512_add_epi32( a_int32_1, c_int32_5p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x64: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_max_epi32( selector1, c_int32_0p3 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); + + // c[1,48-63] + c_int32_1p3 = _mm512_max_epi32( selector1, c_int32_1p3 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); + + // c[2,48-63] + c_int32_2p3 = _mm512_max_epi32( selector1, c_int32_2p3 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_max_epi32( selector1, c_int32_3p2 ); + + // c[3,48-63] + c_int32_3p3 = _mm512_max_epi32( selector1, c_int32_3p3 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + c_int32_4p1 = _mm512_max_epi32( selector1, c_int32_4p1 ); + + // c[4,32-47] + c_int32_4p2 = _mm512_max_epi32( selector1, c_int32_4p2 ); + + // c[4,48-63] + c_int32_4p3 = _mm512_max_epi32( selector1, c_int32_4p3 ); + + // c[5,0-15] + c_int32_5p0 = _mm512_max_epi32( selector1, c_int32_5p0 ); + + // c[5,16-31] + c_int32_5p1 = _mm512_max_epi32( selector1, c_int32_5p1 ); + + // c[5,32-47] + c_int32_5p2 = _mm512_max_epi32( selector1, c_int32_5p2 ); + + // c[5,48-63] + c_int32_5p3 = _mm512_max_epi32( selector1, c_int32_5p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x64: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_2p3) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_3p2) + + // c[3, 48-63] + 
RELU_SCALE_OP_S32_AVX512(c_int32_3p3) + + // c[4, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_4p2) + + // c[4, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_4p3) + + // c[5, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_5p0) + + // c[5, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_5p1) + + // c[5, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_5p2) + + // c[5, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_5p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_6x64: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 3 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[0, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_0p3,a_int32_1,0,3); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[1, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); + + // c[1, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_1p3,a_int32_1,1,3); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[2, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); + + // c[2, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_2p3,a_int32_1,2,3); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[3, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); + + // c[3, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_3p2,a_int32_0,3,2); + + // c[3, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_3p3,a_int32_1,3,3); + + // c[4, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); + + // c[4, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_4p1,selector2,4,1); + + // c[4, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_4p2,a_int32_0,4,2); + + // c[4, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_4p3,a_int32_1,4,3); + + // c[5, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_5p0,selector1,5,0); + + // c[5, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_5p1,selector2,5,1); + + // c[5, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_5p2,a_int32_0,5,2); + + // c[5, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_5p3,a_int32_1,5,3); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x64_DISABLE: + ; + + // Store the results. 
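+		// The full 16-element vectors can be stored unmasked here: the
+		// n0 < NR case was already dispatched to the n fringe kernels at
+		// the top of this function.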
+ // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ), c_int32_0p2 ); + + // c[0,48-63] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 3*16 ), c_int32_0p3 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ), c_int32_1p1 ); + + // c[1,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ), c_int32_1p2 ); + + // c[1,48-63] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 3*16 ), c_int32_1p3 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ), c_int32_2p1 ); + + // c[2,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ), c_int32_2p2 ); + + // c[2,48-63] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 3*16 ), c_int32_2p3 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_int32_3p0 ); + + // c[3,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ), c_int32_3p1 ); + + // c[3,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ), c_int32_3p2 ); + + // c[3,48-63] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 3*16 ), c_int32_3p3 ); + + // c[4,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_int32_4p0 ); + + // c[4,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ), c_int32_4p1 ); + + // c[4,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ), c_int32_4p2 ); + + // c[4,48-63] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 3*16 ), c_int32_4p3 ); + + // c[5,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_int32_5p0 ); + + // c[5,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_int32_5p1 ); + + // c[5,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ), c_int32_5p2 ); + + // c[5,48-63] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 3*16 ), c_int32_5p3 ); + + a = a + ( MR * ps_a ); + post_op_c_i += MR; + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + // In cases where A matrix is packed cs_a is set to 24, since the + // next column in a given row is accessed after 4*6 elements, where + // 6 is MR and 4 elements are broadcasted each time from A (vnni). + // In fringe case, where m < MR, the next column will be after m'*4 + // elements, and subsequently following adjustment of cs_a is + // required before calling m fringe kernels. + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_u8s8s32o32_5x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 4 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_u8s8s32o32_4x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 3 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 
4 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_u8s8s32o32_3x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 2 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_u8s8s32o32_2x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 1 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_u8s8s32o32_1x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + } +} +#endif diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c new file mode 100644 index 0000000000..1674a22bd0 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c @@ -0,0 +1,2362 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include + +#include "blis.h" +#include "lpgemm_kernels.h" +#include "lpgemm_s32_kern_macros.h" + +#ifdef BLIS_KERNELS_ZEN4 +// 5x64 int8o32 kernel +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_5x64_DISABLE, + &&POST_OPS_BIAS_5x64, + &&POST_OPS_RELU_5x64, + &&POST_OPS_RELU_SCALE_5x64, + &&POST_OPS_DOWNSCALE_5x64 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + __m512i b3; + + // A matrix storage. 
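+	// Only two registers are used for A: the 4-byte (k = 4) broadcasts of
+	// successive rows alternate between a_int32_0 and a_int32_1, which
+	// lets the broadcast for the next row be issued while the dpbusd
+	// updates of the previous row are still in flight.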
+ __m512i a_int32_0; + __m512i a_int32_1; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + __m512i c_int32_0p3 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + __m512i c_int32_1p2 = _mm512_setzero_epi32(); + __m512i c_int32_1p3 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + __m512i c_int32_2p2 = _mm512_setzero_epi32(); + __m512i c_int32_2p3 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + __m512i c_int32_3p1 = _mm512_setzero_epi32(); + __m512i c_int32_3p2 = _mm512_setzero_epi32(); + __m512i c_int32_3p3 = _mm512_setzero_epi32(); + + __m512i c_int32_4p0 = _mm512_setzero_epi32(); + __m512i c_int32_4p1 = _mm512_setzero_epi32(); + __m512i c_int32_4p2 = _mm512_setzero_epi32(); + __m512i c_int32_4p3 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); + c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-63] = a[3,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_1, b0 ); + + // Broadcast a[4,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_1, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_1, b2 ); + c_int32_3p3 = _mm512_dpbusd_epi32( c_int32_3p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. 
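+	// ( _mm512_dpbusd_epi32( acc, a, b ): in every 32-bit lane, the four
+	//   unsigned 8-bit values from the broadcast A word are multiplied by
+	//   the four signed 8-bit values from B and the products are summed
+	//   into the lane, so one instruction advances k by 4 for 16 columns. )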
+ // c[4,0-63] = a[4,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); + c_int32_4p3 = _mm512_dpbusd_epi32( c_int32_4p3, a_int32_0, b3 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); + c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-63] = a[3,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_1, b0 ); + + // Broadcast a[4,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_1, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_1, b2 ); + c_int32_3p3 = _mm512_dpbusd_epi32( c_int32_3p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. 
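+	// ( a_kfringe_buf is zero-initialized and only the k_partial_pieces
+	//   (1-3) valid bytes of each A row are copied into it, so the unused
+	//   byte positions of the broadcast stay zero and the corresponding
+	//   B bytes contribute nothing to the dpbusd accumulation. )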
+ // c[4,0-63] = a[4,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); + c_int32_4p3 = _mm512_dpbusd_epi32( c_int32_4p3, a_int32_0, b3 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + c_int32_0p3 = _mm512_mullo_epi32( selector1, c_int32_0p3 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); + c_int32_1p3 = _mm512_mullo_epi32( selector1, c_int32_1p3 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); + c_int32_2p3 = _mm512_mullo_epi32( selector1, c_int32_2p3 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); + c_int32_3p2 = _mm512_mullo_epi32( selector1, c_int32_3p2 ); + c_int32_3p3 = _mm512_mullo_epi32( selector1, c_int32_3p3 ); + + c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); + c_int32_4p1 = _mm512_mullo_epi32( selector1, c_int32_4p1 ); + c_int32_4p2 = _mm512_mullo_epi32( selector1, c_int32_4p2 ); + c_int32_4p3 = _mm512_mullo_epi32( selector1, c_int32_4p3 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p3 = _mm512_add_epi32( selector1, c_int32_0p3 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); + + // c[1,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p3 = _mm512_add_epi32( selector1, c_int32_1p3 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, 
c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); + + // c[2,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p3 = _mm512_add_epi32( selector1, c_int32_2p3 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p2 = _mm512_add_epi32( selector1, c_int32_3p2 ); + + // c[3,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p3 = _mm512_add_epi32( selector1, c_int32_3p3 ); + + // c[4,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p1 = _mm512_add_epi32( selector1, c_int32_4p1 ); + + // c[4,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p2 = _mm512_add_epi32( selector1, c_int32_4p2 ); + + // c[4,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p3 = _mm512_add_epi32( selector1, c_int32_4p3 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5x64: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_add_epi32( a_int32_1, c_int32_0p3 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); + + // c[1,48-63] + c_int32_1p3 = _mm512_add_epi32( a_int32_1, c_int32_1p3 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 
); + + // c[2,32-47] + c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); + + // c[2,48-63] + c_int32_2p3 = _mm512_add_epi32( a_int32_1, c_int32_2p3 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3, 16-31] + c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_add_epi32( a_int32_0, c_int32_3p2 ); + + // c[3,48-63] + c_int32_3p3 = _mm512_add_epi32( a_int32_1, c_int32_3p3 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4, 16-31] + c_int32_4p1 = _mm512_add_epi32( selector2, c_int32_4p1 ); + + // c[4,32-47] + c_int32_4p2 = _mm512_add_epi32( a_int32_0, c_int32_4p2 ); + + // c[4,48-63] + c_int32_4p3 = _mm512_add_epi32( a_int32_1, c_int32_4p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x64: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_max_epi32( selector1, c_int32_0p3 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); + + // c[1,48-63] + c_int32_1p3 = _mm512_max_epi32( selector1, c_int32_1p3 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); + + // c[2,48-63] + c_int32_2p3 = _mm512_max_epi32( selector1, c_int32_2p3 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_max_epi32( selector1, c_int32_3p2 ); + + // c[3,48-63] + c_int32_3p3 = _mm512_max_epi32( selector1, c_int32_3p3 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + c_int32_4p1 = _mm512_max_epi32( selector1, c_int32_4p1 ); + + // c[4,32-47] + c_int32_4p2 = _mm512_max_epi32( selector1, c_int32_4p2 ); + + // c[4,48-63] + c_int32_4p3 = _mm512_max_epi32( selector1, c_int32_4p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x64: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_2p3) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_3p1) + + // c[3, 32-47] + 
RELU_SCALE_OP_S32_AVX512(c_int32_3p2) + + // c[3, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_3p3) + + // c[4, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_4p2) + + // c[4, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_4p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_5x64: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 3 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[0, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_0p3,a_int32_1,0,3); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[1, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); + + // c[1, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_1p3,a_int32_1,1,3); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[2, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); + + // c[2, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_2p3,a_int32_1,2,3); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[3, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); + + // c[3, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_3p2,a_int32_0,3,2); + + // c[3, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_3p3,a_int32_1,3,3); + + // c[4, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); + + // c[4, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_4p1,selector2,4,1); + + // c[4, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_4p2,a_int32_0,4,2); + + // c[4, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_4p3,a_int32_1,4,3); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x64_DISABLE: + ; + + // Store the results. 
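+	// The post-op blocks above are reached through the computed-goto table
+	// post_ops_labels; the POST_OP_LABEL_* macros appear to step through the
+	// attached post-op list and fall through to this *_DISABLE label once it
+	// is exhausted, after which the int32 results are stored unconditionally.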
+ // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); + + // c[0,48-63] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 3*16 ), c_int32_0p3 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[1,32-47] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); + + // c[1,48-63] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 3*16 ), c_int32_1p3 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); + + // c[2,32-47] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 2*16 ), c_int32_2p2 ); + + // c[2,48-63] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 3*16 ), c_int32_2p3 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); + + // c[3,16-31] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 1*16 ), c_int32_3p1 ); + + // c[3,32-47] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 2*16 ), c_int32_3p2 ); + + // c[3,48-63] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 3*16 ), c_int32_3p3 ); + + // c[4,0-15] + _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 0*16 ), c_int32_4p0 ); + + // c[4,16-31] + _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 1*16 ), c_int32_4p1 ); + + // c[4,32-47] + _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 2*16 ), c_int32_4p2 ); + + // c[4,48-63] + _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 3*16 ), c_int32_4p3 ); +} + +// 4x64 int8o32 kernel +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_4x64_DISABLE, + &&POST_OPS_BIAS_4x64, + &&POST_OPS_RELU_4x64, + &&POST_OPS_RELU_SCALE_4x64, + &&POST_OPS_DOWNSCALE_4x64 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + __m512i b3; + + // A matrix storage. + __m512i a_int32_0; + __m512i a_int32_1; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + __m512i c_int32_0p3 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + __m512i c_int32_1p2 = _mm512_setzero_epi32(); + __m512i c_int32_1p3 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + __m512i c_int32_2p2 = _mm512_setzero_epi32(); + __m512i c_int32_2p3 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + __m512i c_int32_3p1 = _mm512_setzero_epi32(); + __m512i c_int32_3p2 = _mm512_setzero_epi32(); + __m512i c_int32_3p3 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. 
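+	// ( Each _mm512_loadu_epi8 above pulls 64 bytes of the reordered B
+	//   buffer; with the VNNI-friendly packing assumed here, that is 4
+	//   k-values for each of 16 output columns, so b0..b3 together cover
+	//   column groups 0-15, 16-31, 32-47 and 48-63. )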
+ // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); + c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-63] = a[3,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_1, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_1, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_1, b2 ); + c_int32_3p3 = _mm512_dpbusd_epi32( c_int32_3p3, a_int32_1, b3 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); + c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. 
+ // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-63] = a[3,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_1, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_1, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_1, b2 ); + c_int32_3p3 = _mm512_dpbusd_epi32( c_int32_3p3, a_int32_1, b3 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + c_int32_0p3 = _mm512_mullo_epi32( selector1, c_int32_0p3 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); + c_int32_1p3 = _mm512_mullo_epi32( selector1, c_int32_1p3 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); + c_int32_2p3 = _mm512_mullo_epi32( selector1, c_int32_2p3 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); + c_int32_3p2 = _mm512_mullo_epi32( selector1, c_int32_3p2 ); + c_int32_3p3 = _mm512_mullo_epi32( selector1, c_int32_3p3 ); + + // Scale C by beta. 
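+	// When beta is zero the existing C tile is never read, so C may be
+	// uninitialized; otherwise each 16-lane block is loaded, multiplied by
+	// beta with _mm512_mullo_epi32 and added to the alpha-scaled products,
+	// giving C = alpha*A*B + beta*C in 32-bit integer arithmetic.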
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p3 = _mm512_add_epi32( selector1, c_int32_0p3 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); + + // c[1,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p3 = _mm512_add_epi32( selector1, c_int32_1p3 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); + + // c[2,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p3 = _mm512_add_epi32( selector1, c_int32_2p3 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p2 = _mm512_add_epi32( selector1, c_int32_3p2 ); + + // c[3,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p3 = _mm512_add_epi32( selector1, c_int32_3p3 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x64: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j 
+ ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_add_epi32( a_int32_1, c_int32_0p3 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); + + // c[1,48-63] + c_int32_1p3 = _mm512_add_epi32( a_int32_1, c_int32_1p3 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); + + // c[2,48-63] + c_int32_2p3 = _mm512_add_epi32( a_int32_1, c_int32_2p3 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3, 16-31] + c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_add_epi32( a_int32_0, c_int32_3p2 ); + + // c[3,48-63] + c_int32_3p3 = _mm512_add_epi32( a_int32_1, c_int32_3p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x64: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_max_epi32( selector1, c_int32_0p3 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); + + // c[1,48-63] + c_int32_1p3 = _mm512_max_epi32( selector1, c_int32_1p3 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); + + // c[2,48-63] + c_int32_2p3 = _mm512_max_epi32( selector1, c_int32_2p3 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_max_epi32( selector1, c_int32_3p2 ); + + // c[3,48-63] + c_int32_3p3 = _mm512_max_epi32( selector1, c_int32_3p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x64: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_1p3) + + // 
c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_2p3) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_3p2) + + // c[3, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_3p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x64: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 3 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[0, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_0p3,a_int32_1,0,3); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[1, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); + + // c[1, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_1p3,a_int32_1,1,3); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[2, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); + + // c[2, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_2p3,a_int32_1,2,3); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[3, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); + + // c[3, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_3p2,a_int32_0,3,2); + + // c[3, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_3p3,a_int32_1,3,3); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x64_DISABLE: + ; + + // Store the results. 
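+	// Unlike the 6x64 main kernel, the m-fringe kernels index C from row 0:
+	// the caller passes c already offset by rs_c * m_full_pieces_loop_limit,
+	// so only the fringe rows are touched here.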
+ // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); + + // c[0,48-63] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 3*16 ), c_int32_0p3 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[1,32-47] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); + + // c[1,48-63] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 3*16 ), c_int32_1p3 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); + + // c[2,32-47] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 2*16 ), c_int32_2p2 ); + + // c[2,48-63] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 3*16 ), c_int32_2p3 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); + + // c[3,16-31] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 1*16 ), c_int32_3p1 ); + + // c[3,32-47] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 2*16 ), c_int32_3p2 ); + + // c[3,48-63] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 3*16 ), c_int32_3p3 ); +} + +// 3x64 int8o32 kernel +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_3x64_DISABLE, + &&POST_OPS_BIAS_3x64, + &&POST_OPS_RELU_3x64, + &&POST_OPS_RELU_SCALE_3x64, + &&POST_OPS_DOWNSCALE_3x64 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + __m512i b3; + + // A matrix storage. + __m512i a_int32_0; + __m512i a_int32_1; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + __m512i c_int32_0p3 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + __m512i c_int32_1p2 = _mm512_setzero_epi32(); + __m512i c_int32_1p3 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + __m512i c_int32_2p2 = _mm512_setzero_epi32(); + __m512i c_int32_2p3 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. 
+ // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); + c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); + c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. 
+ // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + c_int32_0p3 = _mm512_mullo_epi32( selector1, c_int32_0p3 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); + c_int32_1p3 = _mm512_mullo_epi32( selector1, c_int32_1p3 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); + c_int32_2p3 = _mm512_mullo_epi32( selector1, c_int32_2p3 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p3 = _mm512_add_epi32( selector1, c_int32_0p3 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); + + // c[1,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p3 = _mm512_add_epi32( selector1, c_int32_1p3 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); + + // c[2,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 3*16 ) ); + 
selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p3 = _mm512_add_epi32( selector1, c_int32_2p3 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3x64: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_add_epi32( a_int32_1, c_int32_0p3 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); + + // c[1,48-63] + c_int32_1p3 = _mm512_add_epi32( a_int32_1, c_int32_1p3 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); + + // c[2,48-63] + c_int32_2p3 = _mm512_add_epi32( a_int32_1, c_int32_2p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x64: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_max_epi32( selector1, c_int32_0p3 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); + + // c[1,48-63] + c_int32_1p3 = _mm512_max_epi32( selector1, c_int32_1p3 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); + + // c[2,48-63] + c_int32_2p3 = _mm512_max_epi32( selector1, c_int32_2p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x64: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_2p2) + + // c[2, 48-63] 
+ RELU_SCALE_OP_S32_AVX512(c_int32_2p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_3x64: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 3 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[0, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_0p3,a_int32_1,0,3); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[1, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); + + // c[1, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_1p3,a_int32_1,1,3); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[2, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); + + // c[2, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_2p3,a_int32_1,2,3); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x64_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); + + // c[0,48-63] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 3*16 ), c_int32_0p3 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[1,32-47] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); + + // c[1,48-63] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 3*16 ), c_int32_1p3 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); + + // c[2,32-47] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 2*16 ), c_int32_2p2 ); + + // c[2,48-63] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 3*16 ), c_int32_2p3 ); +} + +// 2x64 int8o32 kernel +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_2x64_DISABLE, + &&POST_OPS_BIAS_2x64, + &&POST_OPS_RELU_2x64, + &&POST_OPS_RELU_SCALE_2x64, + &&POST_OPS_DOWNSCALE_2x64 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + __m512i b3; + + // A matrix storage. + __m512i a_int32_0; + __m512i a_int32_1; + + // Registers to use for accumulating C. 
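The 2x64 kernel that begins here keeps its whole 2x64 int32 output tile in eight zmm accumulators, four per row of C, just as the 3x64 kernel above uses twelve. As a reader aid, the sketch below is a scalar reference of what such an MR x NR u8s8s32 tile computes before post-ops; ref_u8s8s32_tile is a hypothetical name, and it assumes a plain row-major B, whereas the kernel consumes B in a packed, VNNI-friendly layout (groups of four k values per 32-bit lane).

#include <stdint.h>
#include <stddef.h>

/* Scalar reference for an MR x NR u8s8s32 tile: C = alpha*(A*B) + beta*C. */
static void ref_u8s8s32_tile
     (
       size_t MR, size_t NR, size_t k,
       const uint8_t* a, size_t rs_a,
       const int8_t*  b, size_t rs_b,
       int32_t alpha, int32_t beta,
       int32_t* c, size_t rs_c
     )
{
	for ( size_t i = 0; i < MR; ++i )
	{
		for ( size_t j = 0; j < NR; ++j )
		{
			int32_t acc = 0;
			for ( size_t p = 0; p < k; ++p )
			{
				acc += ( int32_t )a[ ( i * rs_a ) + p ] *
				       ( int32_t )b[ ( p * rs_b ) + j ];
			}
			/* C is only read when beta is non-zero, matching the kernels. */
			c[ ( i * rs_c ) + j ] = ( alpha * acc ) +
			       ( ( beta != 0 ) ? ( beta * c[ ( i * rs_c ) + j ] ) : 0 );
		}
	}
}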
+ __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + __m512i c_int32_0p3 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + __m512i c_int32_1p2 = _mm512_setzero_epi32(); + __m512i c_int32_1p3 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); + c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. 
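Every inner-loop update above and below is a single _mm512_dpbusd_epi32: per 32-bit lane it multiplies four unsigned bytes of the broadcast A group by four signed bytes of the packed B vector, sums the four products, and adds the result into the int32 accumulator without saturation. A scalar model of one lane (dpbusd_lane is an illustrative name):

#include <stdint.h>

/* One 32-bit lane of VPDPBUSD: a4 holds four uint8 values of A (the k group of 4),
   b4 holds the matching four int8 values of B for this output column. */
static int32_t dpbusd_lane( int32_t acc, uint32_t a4, uint32_t b4 )
{
	for ( int i = 0; i < 4; ++i )
	{
		uint8_t au = ( uint8_t )( a4 >> ( 8 * i ) );
		int8_t  bs = ( int8_t  )( b4 >> ( 8 * i ) );
		acc += ( int32_t )au * ( int32_t )bs;
	}
	return acc;
}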
+ // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); + c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + c_int32_0p3 = _mm512_mullo_epi32( selector1, c_int32_0p3 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); + c_int32_1p3 = _mm512_mullo_epi32( selector1, c_int32_1p3 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p3 = _mm512_add_epi32( selector1, c_int32_0p3 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); + + // c[1,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p3 = _mm512_add_epi32( selector1, c_int32_1p3 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x64: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_add_epi32( a_int32_1, c_int32_0p3 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, 
c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); + + // c[1,48-63] + c_int32_1p3 = _mm512_add_epi32( a_int32_1, c_int32_1p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x64: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_max_epi32( selector1, c_int32_0p3 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); + + // c[1,48-63] + c_int32_1p3 = _mm512_max_epi32( selector1, c_int32_1p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x64: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_1p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2x64: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 3 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[0, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_0p3,a_int32_1,0,3); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[1, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); + + // c[1, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_1p3,a_int32_1,1,3); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x64_DISABLE: + ; + + // Store the results. 
+ // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); + + // c[0,48-63] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 3*16 ), c_int32_0p3 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[1,32-47] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); + + // c[1,48-63] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 3*16 ), c_int32_1p3 ); +} + +// 1x64 int8o32 kernel +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_1x64_DISABLE, + &&POST_OPS_BIAS_1x64, + &&POST_OPS_RELU_1x64, + &&POST_OPS_RELU_SCALE_1x64, + &&POST_OPS_DOWNSCALE_1x64 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + __m512i b3; + + // A matrix storage. + __m512i a_int32_0; + __m512i a_int32_1; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + __m512i c_int32_0p3 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr] + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. 
+ // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + c_int32_0p3 = _mm512_mullo_epi32( selector1, c_int32_0p3 ); + + // Scale C by beta. + if ( beta != 0) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p3 = _mm512_add_epi32( selector1, c_int32_0p3 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x64: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_add_epi32( a_int32_1, c_int32_0p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x64: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_max_epi32( selector1, c_int32_0p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x64: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_0p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1x64: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( 
float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 3 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[0, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_0p3,a_int32_1,0,3); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x64_DISABLE: + ; + + // Store the accumulated results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); + + // c[0,48-63] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 3*16 ), c_int32_0p3 ); +} +#endif diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c new file mode 100644 index 0000000000..b202061e6a --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c @@ -0,0 +1,5283 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include + +#include "blis.h" +#include "lpgemm_kernels.h" +#include "lpgemm_s32_kern_macros.h" + +#ifdef BLIS_KERNELS_ZEN4 +// 5xlt16 int8o32 fringe kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5xlt16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_5xLT16_DISABLE, + &&POST_OPS_BIAS_5xLT16, + &&POST_OPS_RELU_5xLT16, + &&POST_OPS_RELU_SCALE_5xLT16, + &&POST_OPS_DOWNSCALE_5xLT16 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // For corner cases. 
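The xLT16 kernels in this file handle a column remainder n0_rem < 16 by bouncing the edge of C (and the bias / scale vectors) through full 16-element stack buffers with memcpy, so every vector instruction stays full width and nothing is read or written past column n0_rem. A minimal sketch of the two staging helpers this pattern amounts to (load_c_partial / store_c_partial are hypothetical names, not part of the patch):

#include <immintrin.h>
#include <string.h>
#include <stdint.h>

/* Read only n0_rem valid int32 values, then operate on a full 512-bit vector. */
static __m512i load_c_partial( const int32_t* c_row, int n0_rem )
{
	int32_t buf[16] = { 0 };
	memcpy( buf, c_row, n0_rem * sizeof( int32_t ) );
	return _mm512_loadu_si512( buf );
}

/* Write the full vector to the buffer, then copy back only n0_rem values. */
static void store_c_partial( int32_t* c_row, __m512i v, int n0_rem )
{
	int32_t buf[16];
	_mm512_storeu_si512( buf, v );
	memcpy( c_row, buf, n0_rem * sizeof( int32_t ) );
}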
+ int32_t buf0[16]; + int32_t buf1[16]; + int32_t buf2[16]; + int32_t buf3[16]; + int32_t buf4[16]; + + { + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + + __m512i c_int32_4p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + + // Broadcast a[4,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. 
+ // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + + // Broadcast a[4,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + + c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf1, ( c + ( rs_c * 1 ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf2, ( c + ( rs_c * 2 ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf3, ( c + ( rs_c * 3 ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf4, ( c + ( rs_c * 4 ) ), ( n0_rem * sizeof( int32_t ) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_epi32( buf0 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( buf1 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( buf2 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( buf3 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + selector1 = _mm512_loadu_epi32( buf4 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5xLT16: + { + memcpy( buf0, ( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( int32_t ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5xLT16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5xLT16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + 
_mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[4, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_5xLT16: + { + memcpy( buf0, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_0p0,selector1,0,0); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_1p0,selector1,1,0); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_2p0,selector1,2,0); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_3p0,selector1,3,0); + + // c[4, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_4p0,selector1,4,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5xLT16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( buf0, c_int32_0p0 ); + + // c[1,0-15] + _mm512_storeu_epi32( buf1, c_int32_1p0 ); + + // c[2,0-15] + _mm512_storeu_epi32( buf2, c_int32_2p0 ); + + // c[3,0-15] + _mm512_storeu_epi32( buf3, c_int32_3p0 ); + + // c[4,0-15] + _mm512_storeu_epi32( buf4, c_int32_4p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( int32_t ) ) ); + + // c[1,0-15] + memcpy( c + ( rs_c * 1 ) + ( 0*16 ), buf1, ( n0_rem * sizeof( int32_t ) ) ); + + // c[2,0-15] + memcpy( c + ( rs_c * 2 ) + ( 0*16 ), buf2, ( n0_rem * sizeof( int32_t ) ) ); + + // c[3,0-15] + memcpy( c + ( rs_c * 3 ) + ( 0*16 ), buf3, ( n0_rem * sizeof( int32_t ) ) ); + + // c[4,0-15] + memcpy( c + ( rs_c * 4 ) + ( 0*16 ), buf4, ( n0_rem * sizeof( int32_t ) ) ); + } +} + +// 4xlt16 int8o32 fringe kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4xlt16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_4xLT16_DISABLE, + &&POST_OPS_BIAS_4xLT16, + &&POST_OPS_RELU_4xLT16, + &&POST_OPS_RELU_SCALE_4xLT16, + &&POST_OPS_DOWNSCALE_4xLT16 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // For corner cases. + int32_t buf0[16]; + int32_t buf1[16]; + int32_t buf2[16]; + int32_t buf3[16]; + + { + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. 
+ // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + + + // Scale C by beta. 
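As in the wider kernels, C is only ever read when beta is non-zero; each 16-lane chunk of C is multiplied by beta with _mm512_mullo_epi32 and added to the already alpha-scaled accumulator, i.e. c = alpha*(A*B) + beta*c in wrap-around int32 arithmetic. A sketch of that step for one full-width chunk (apply_beta is an illustrative name; the xLT16 path below additionally stages C through the stack buffers first):

#include <immintrin.h>
#include <stdint.h>

static __m512i apply_beta( __m512i acc, const int32_t* c_row, int32_t beta )
{
	if ( beta == 0 ) return acc;            /* C is not read when beta == 0. */

	__m512i beta_v = _mm512_set1_epi32( beta );
	__m512i c_v    = _mm512_loadu_si512( c_row );
	return _mm512_add_epi32( _mm512_mullo_epi32( beta_v, c_v ), acc );
}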
+ if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf1, ( c + ( rs_c * 1 ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf2, ( c + ( rs_c * 2 ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf3, ( c + ( rs_c * 3 ) ), ( n0_rem * sizeof( int32_t ) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_epi32( buf0 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( buf1 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( buf2 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( buf3 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4xLT16: + { + memcpy( buf0, ( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( int32_t ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4xLT16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4xLT16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4xLT16: + { + memcpy( buf0, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_0p0,selector1,0,0); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_1p0,selector1,1,0); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_2p0,selector1,2,0); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_3p0,selector1,3,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4xLT16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( buf0, c_int32_0p0 ); + + // c[1,0-15] + _mm512_storeu_epi32( buf1, c_int32_1p0 ); + + // c[2,0-15] + _mm512_storeu_epi32( buf2, c_int32_2p0 ); + + // c[3,0-15] + _mm512_storeu_epi32( buf3, c_int32_3p0 ); + + // Memcpy partial parts. 
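For the partial-width store-back that follows, the patch writes each full accumulator to its stack buffer and then memcpys n0_rem int32 values into C. A masked store would be an equivalent alternative on AVX-512; it is shown here only to illustrate the design choice, not as what this kernel does (store_c_masked is a hypothetical helper):

#include <immintrin.h>
#include <stdint.h>

static void store_c_masked( int32_t* c_row, __m512i v, int n0_rem )
{
	__mmask16 k = ( __mmask16 )( ( 1u << n0_rem ) - 1u );  /* low n0_rem lanes enabled */
	_mm512_mask_storeu_epi32( c_row, k, v );
}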
+ // c[0,0-15] + memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( int32_t ) ) ); + + // c[1,0-15] + memcpy( c + ( rs_c * 1 ) + ( 0*16 ), buf1, ( n0_rem * sizeof( int32_t ) ) ); + + // c[2,0-15] + memcpy( c + ( rs_c * 2 ) + ( 0*16 ), buf2, ( n0_rem * sizeof( int32_t ) ) ); + + // c[3,0-15] + memcpy( c + ( rs_c * 3 ) + ( 0*16 ), buf3, ( n0_rem * sizeof( int32_t ) ) ); + } +} + +// 3xlt16 int8o32 fringe kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3xlt16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_3xLT16_DISABLE, + &&POST_OPS_BIAS_3xLT16, + &&POST_OPS_RELU_3xLT16, + &&POST_OPS_RELU_SCALE_3xLT16, + &&POST_OPS_DOWNSCALE_3xLT16 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // For corner cases. + int32_t buf0[16]; + int32_t buf1[16]; + int32_t buf2[16]; + + { + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. 
+ // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf1, ( c + ( rs_c * 1 ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf2, ( c + ( rs_c * 2 ) ), ( n0_rem * sizeof( int32_t ) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_epi32( buf0 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( buf1 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( buf2 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3xLT16: + { + memcpy( buf0, ( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( int32_t ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3xLT16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3xLT16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_3xLT16: + { + memcpy( buf0, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_0p0,selector1,0,0); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_1p0,selector1,1,0); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_2p0,selector1,2,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3xLT16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( buf0, c_int32_0p0 ); + + // c[1,0-15] + _mm512_storeu_epi32( buf1, c_int32_1p0 ); + + // c[2,0-15] + _mm512_storeu_epi32( buf2, c_int32_2p0 ); + + // Memcpy partial parts. 
+ // c[0,0-15] + memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( int32_t ) ) ); + + // c[1,0-15] + memcpy( c + ( rs_c * 1 ) + ( 0*16 ), buf1, ( n0_rem * sizeof( int32_t ) ) ); + + // c[2,0-15] + memcpy( c + ( rs_c * 2 ) + ( 0*16 ), buf2, ( n0_rem * sizeof( int32_t ) ) ); + } +} + +// 2xlt16 int8o32 fringe kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2xlt16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_2xLT16_DISABLE, + &&POST_OPS_BIAS_2xLT16, + &&POST_OPS_RELU_2xLT16, + &&POST_OPS_RELU_SCALE_2xLT16, + &&POST_OPS_DOWNSCALE_2xLT16 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // For corner cases. + int32_t buf0[16]; + int32_t buf1[16]; + + { + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf1, ( c + ( rs_c * 1 ) ), ( n0_rem * sizeof( int32_t ) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_epi32( buf0 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( buf1 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2xLT16: + { + memcpy( buf0, ( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( int32_t ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2xLT16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2xLT16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2xLT16: + { + memcpy( buf0, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_0p0,selector1,0,0); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_1p0,selector1,1,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2xLT16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( buf0, c_int32_0p0 ); + + // c[1,0-15] + _mm512_storeu_epi32( buf1, c_int32_1p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( int32_t ) ) ); + + // c[1,0-15] + memcpy( c + ( rs_c * 1 ) + ( 0*16 ), buf1, ( n0_rem * sizeof( int32_t ) ) ); + } +} + +// 1xlt16 int8o32 fringe kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1xlt16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_1xLT16_DISABLE, + &&POST_OPS_BIAS_1xLT16, + &&POST_OPS_RELU_1xLT16, + &&POST_OPS_RELU_SCALE_1xLT16, + &&POST_OPS_DOWNSCALE_1xLT16 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // For corner cases. + int32_t buf0[16]; + + { + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + } + // Handle k remainder. 
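When k0 is not a multiple of 4, the remaining 1-3 A bytes are copied into the zero-initialized 32-bit scratch a_kfringe_buf before the broadcast; the untouched high bytes stay zero, so their u8*s8 products contribute nothing to the dpbusd sums. A sketch of that step (broadcast_a_fringe is an illustrative name):

#include <immintrin.h>
#include <string.h>
#include <stdint.h>

static __m512i broadcast_a_fringe( const uint8_t* a_row_k, int k_partial_pieces )
{
	uint32_t scratch = 0;                        /* padded byte positions remain zero */
	memcpy( &scratch, a_row_k, ( size_t )k_partial_pieces );
	return _mm512_set1_epi32( ( int32_t )scratch );
}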
+ if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( int32_t ) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_epi32( buf0 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1xLT16: + { + memcpy( buf0, ( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( int32_t ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1xLT16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1xLT16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1xLT16: + { + memcpy( buf0, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_0p0,selector1,0,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1xLT16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( buf0, c_int32_0p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( int32_t ) ) ); + } +} + +// 5x16 int8o32 kernel +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_5x16_DISABLE, + &&POST_OPS_BIAS_5x16, + &&POST_OPS_RELU_5x16, + &&POST_OPS_RELU_SCALE_5x16, + &&POST_OPS_DOWNSCALE_5x16 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + + __m512i c_int32_4p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. 
+ // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + + // Broadcast a[4,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + + // Broadcast a[4,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. 
+ // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + + c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5x16: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[4, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_5x16: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[1, 0-15] + 
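	// Editor's note (sketch only, not part of this patch): the
	// CVT_MULRND_CVT32_CVT8* macros implement the "downscale" post-op. As I read
	// it -- an assumption, since the exact rounding/saturation behaviour is
	// defined by the macro itself -- the per-element flow is roughly the
	// following standalone reference (hypothetical helper, requires <math.h>):

	int8_t downscale_ref( int32_t acc, float scale )
	{
	    float scaled = ( float )acc * scale;        // apply per-column scale factor
	    int32_t r = ( int32_t )lrintf( scaled );    // round to nearest
	    if ( r >  127 ) r =  127;                   // clamp to int8 range
	    if ( r < -128 ) r = -128;
	    return ( int8_t )r;
	}

	// selector1/selector2 here carry the raw bits of the float scale factors
	// (loaded with _mm512_loadu_epi32); the macro reinterprets them as floats.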
CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[4, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); + + // c[4,0-15] + _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 0*16 ), c_int32_4p0 ); +} + +// 4x16 int8o32 kernel +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_4x16_DISABLE, + &&POST_OPS_BIAS_4x16, + &&POST_OPS_RELU_4x16, + &&POST_OPS_RELU_SCALE_4x16, + &&POST_OPS_DOWNSCALE_4x16 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. 
+ // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x16: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + 
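	// Editor's note (illustrative, not part of this patch): RELU_SCALE_OP_S32_AVX512
	// above applies, as I understand the macro, a parametric-ReLU style update --
	// lanes below zero are multiplied by the scale taken from op_args2, the rest
	// pass through ( x < 0 ? x * scale : x ). The POST_OP_LABEL_* macro that
	// follows dispatches to the next post-op using the GCC/Clang "labels as
	// values" extension: post_ops_labels[] stores &&LABEL addresses and the macro
	// performs a computed goto into that table, walking post_ops_list until the
	// *_DISABLE label is reached. A minimal standalone sketch of that dispatch
	// pattern (all names hypothetical):

	void computed_goto_demo( const int* ops, int n_ops )
	{
	    static void* labels[] = { &&lbl_done, &&lbl_bias, &&lbl_relu };
	    int i = 0;
	    goto *labels[ ( n_ops > 0 ) ? ops[ 0 ] : 0 ];
	lbl_bias:
	    i++;                                            // apply bias, then jump on
	    goto *labels[ ( i < n_ops ) ? ops[ i ] : 0 ];
	lbl_relu:
	    i++;                                            // apply relu, then jump on
	    goto *labels[ ( i < n_ops ) ? ops[ i ] : 0 ];
	lbl_done:
	    return;
	}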
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x16: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); +} + +// 3x16 int8o32 kernel +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_3x16_DISABLE, + &&POST_OPS_BIAS_3x16, + &&POST_OPS_RELU_3x16, + &&POST_OPS_RELU_SCALE_3x16, + &&POST_OPS_DOWNSCALE_3x16 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. 
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3x16: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_3x16: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x16_DISABLE: + ; + + // Store the results. 
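	// Editor's note (sketch only): these n0 == 16 kernels can store full 64-byte
	// vectors straight into C, while the *LT16 variants above stage the result in
	// buf0 and memcpy only n0_rem elements, so a C panel narrower than 16 columns
	// is never written out of bounds. A standalone sketch of that staging pattern
	// (hypothetical helper):

	static inline void store_partial_s32
	     (
	       int32_t* dst,
	       __m512i  v,
	       dim_t    n_rem,
	       int32_t* scratch
	     )
	{
	    _mm512_storeu_epi32( scratch, v );                   // spill the full vector
	    memcpy( dst, scratch, n_rem * sizeof( int32_t ) );   // copy only n_rem values
	}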
+ // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); +} + +// 2x16 int8o32 kernel +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_2x16_DISABLE, + &&POST_OPS_BIAS_2x16, + &&POST_OPS_RELU_2x16, + &&POST_OPS_RELU_SCALE_2x16, + &&POST_OPS_DOWNSCALE_2x16 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + + // Scale C by beta. 
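	// Editor's note (illustrative, not part of this patch): together with the
	// alpha multiply above, this block realises the integer GEMM epilogue
	// C := alpha*(A*B) + beta*C; the beta term is applied only when beta != 0, so
	// C is never read while it may still be uninitialised. A standalone
	// per-element reference (hypothetical helper):

	static inline int32_t s32_epilogue_ref
	     (
	       int32_t ab,
	       int32_t c_old,
	       int32_t alpha,
	       int32_t beta
	     )
	{
	    int32_t out = alpha * ab;          // scale the accumulated dot product
	    if ( beta != 0 )
	    {
	        out += beta * c_old;           // blend in the existing C value
	    }
	    return out;
	}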
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x16: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2x16: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); +} + +// 1x16 int8o32 kernel +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_1x16_DISABLE, + &&POST_OPS_BIAS_1x16, + &&POST_OPS_RELU_1x16, + &&POST_OPS_RELU_SCALE_1x16, + &&POST_OPS_DOWNSCALE_1x16 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. 
+ // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x16: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1x16: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); +} + +// 5x32 int8o32 kernel +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_5x32_DISABLE, + &&POST_OPS_BIAS_5x32, + &&POST_OPS_RELU_5x32, + &&POST_OPS_RELU_SCALE_5x32, + &&POST_OPS_DOWNSCALE_5x32 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage. + __m512i b0; + __m512i b1; + + // A matrix storage. + __m512i a_int32_0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + __m512i c_int32_3p1 = _mm512_setzero_epi32(); + + __m512i c_int32_4p0 = _mm512_setzero_epi32(); + __m512i c_int32_4p1 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. 
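	// Editor's note (sketch only): the *x32 kernels keep two accumulators per row
	// because each k group loads two 64-byte slices of packed B -- b0 covers
	// output columns 0-15 and b1 covers columns 16-31 -- so every row issues a
	// pair of dpbusd updates. Conceptually (hypothetical helper; c_row[]/b_blk[]
	// stand in for the named registers):

	static inline void dpbusd_row_2blk
	     (
	       __m512i        c_row[ 2 ],
	       __m512i        a_bcast,
	       const __m512i  b_blk[ 2 ]
	     )
	{
	    for ( int j = 0; j < 2; ++j )
	    {
	        c_row[ j ] = _mm512_dpbusd_epi32( c_row[ j ], a_bcast, b_blk[ j ] );
	    }
	}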
+ // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-31] = a[3,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + + // Broadcast a[4,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-31] = a[4,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + + // Broadcast a[3,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-31] = a[3,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + + // Broadcast a[4,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. 
+ // c[4,0-31] = a[4,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); + + c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); + c_int32_4p1 = _mm512_mullo_epi32( selector1, c_int32_4p1 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); + + // c[4,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p1 = _mm512_add_epi32( selector1, c_int32_4p1 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5x32: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[1,0-15] + c_int32_1p0 = 
_mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3, 16-31] + c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4, 16-31] + c_int32_4p1 = _mm512_add_epi32( selector2, c_int32_4p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x32: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + c_int32_4p1 = _mm512_max_epi32( selector1, c_int32_4p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x32: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_3p1) + + // c[4, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_4p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_5x32: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[3, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); + + // c[4, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); + + // c[4, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_4p1,selector2,4,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x32_DISABLE: + ; + + // Store the results. 
+ // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); + + // c[3,16-31] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 1*16 ), c_int32_3p1 ); + + // c[4,0-15] + _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 0*16 ), c_int32_4p0 ); + + // c[4,16-31] + _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 1*16 ), c_int32_4p1 ); +} + +// 4x32 int8o32 kernel +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_4x32_DISABLE, + &&POST_OPS_BIAS_4x32, + &&POST_OPS_RELU_4x32, + &&POST_OPS_RELU_SCALE_4x32, + &&POST_OPS_DOWNSCALE_4x32 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage. + __m512i b0; + __m512i b1; + + // A matrix storage. + __m512i a_int32_0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + __m512i c_int32_3p1 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-31] = a[3,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + } + // Handle k remainder. 
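	// Editor's note (sketch only, mirroring the code below): the fringe path can
	// reuse the same dpbusd instruction for the last 1-3 k values because they
	// are packed into a zero-initialised 32-bit scalar; the zero padding
	// multiplies against B as 0 and contributes nothing to the dot product.
	// Per row (a_row is a hypothetical pointer to the current row of A):

	uint32_t packed = 0;                                    // unused bytes stay 0
	memcpy( &packed, a_row + ( cs_a * k_full_pieces ),
	        k_partial_pieces * sizeof( uint8_t ) );         // copy the 1-3 leftover bytes
	__m512i a_bcast = _mm512_set1_epi32( packed );          // broadcast to all 16 lanes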
+ if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + + // Broadcast a[3,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-31] = a[3,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x32: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3, 16-31] + c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x32: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x32: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 
relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_3p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x32: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[3, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x32_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); + + // c[3,16-31] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 1*16 ), c_int32_3p1 ); +} + +// 3x32 int8o32 kernel +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_3x32_DISABLE, + &&POST_OPS_BIAS_3x32, + &&POST_OPS_RELU_3x32, + &&POST_OPS_RELU_SCALE_3x32, + &&POST_OPS_DOWNSCALE_3x32 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage. + __m512i b0; + __m512i b1; + + // A matrix storage. + __m512i a_int32_0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + + // Broadcast a[1,kr:kr+4]. 
+ a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3x32: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x32: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x32: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_3x32: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + + // c[0, 0-15] + 
CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x32_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); +} + +// 2x32 int8o32 kernel +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_2x32_DISABLE, + &&POST_OPS_BIAS_2x32, + &&POST_OPS_RELU_2x32, + &&POST_OPS_RELU_SCALE_2x32, + &&POST_OPS_DOWNSCALE_2x32 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage. + __m512i b0; + __m512i b1; + + // A matrix storage. + __m512i a_int32_0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. 
+ // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x32: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x32: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x32: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2x32: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + 
CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x32_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); +} + +// 1x32 int8o32 kernel +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_1x32_DISABLE, + &&POST_OPS_BIAS_1x32, + &&POST_OPS_RELU_1x32, + &&POST_OPS_RELU_SCALE_1x32, + &&POST_OPS_DOWNSCALE_1x32 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage. + __m512i b0; + __m512i b1; + + // A matrix storage. + __m512i a_int32_0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + + // Scale C by beta. 
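Annotation: the alpha/beta code that follows is the usual GEMM update C := beta*C + alpha*(A*B), carried out entirely in int32 with `_mm512_mullo_epi32`/`_mm512_add_epi32` (the beta term is skipped when beta == 0). A scalar reference for a single output element, as a sketch only (names are not from the patch):

#include <stdint.h>

// Scalar reference for one element of the int32 update used by these
// kernels: c_out = beta * c_in + alpha * dot(a_row, b_col), with ordinary
// wrap-around int32 arithmetic, matching mullo/add semantics.
static int32_t s32_gemm_elem( int32_t alpha, int32_t beta, int32_t c_in,
                              const uint8_t* a_row, const int8_t* b_col,
                              int k )
{
    int32_t acc = 0;
    for ( int p = 0; p < k; ++p )
    {
        acc += ( int32_t )a_row[p] * ( int32_t )b_col[p];
    }
    return beta * c_in + alpha * acc;
}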
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x32: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x32: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x32: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1x32: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x32_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); +} + +// 5x48 int8o32 kernel +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x48) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_5x48_DISABLE, + &&POST_OPS_BIAS_5x48, + &&POST_OPS_RELU_5x48, + &&POST_OPS_RELU_SCALE_5x48, + &&POST_OPS_DOWNSCALE_5x48 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + + // A matrix storage. + __m512i a_int32_0; + + // Registers to use for accumulating C. 
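Annotation: each 5x48 fringe tile is held entirely in registers, three 512-bit accumulators per row (16 int32 lanes each) covering the 48 output columns, which is where the `c_int32_<row>p<panel>` naming below comes from. A tiny sketch of that index mapping (illustrative only):

#include <stdio.h>

// Map an output element (row, col) of a 5x48 tile onto the accumulator
// naming used by the kernel: c_int32_<row>p<panel>, lane (col % 16).
int main( void )
{
    int row = 3, col = 37;
    int panel = col / 16;   // 0, 1 or 2 -> the p0/p1/p2 suffix
    int lane  = col % 16;   // lane inside the 512-bit register
    printf( "c[%d][%d] -> c_int32_%dp%d, lane %d\n",
            row, col, row, panel, lane );
    return 0;
}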
+ __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + __m512i c_int32_1p2 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + __m512i c_int32_2p2 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + __m512i c_int32_3p1 = _mm512_setzero_epi32(); + __m512i c_int32_3p2 = _mm512_setzero_epi32(); + + __m512i c_int32_4p0 = _mm512_setzero_epi32(); + __m512i c_int32_4p1 = _mm512_setzero_epi32(); + __m512i c_int32_4p2 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-47] = a[3,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_0, b2 ); + + // Broadcast a[4,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-47] = a[4,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. 
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + + // Broadcast a[3,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-47] = a[3,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_0, b2 ); + + // Broadcast a[4,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. 
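Annotation: the memcpy-into-a-zeroed-uint32_t pattern in this k-remainder block is how k % 4 != 0 is handled: only the valid trailing bytes of the A row are copied, the padding bytes stay zero, and the following VPDPBUSD therefore contributes nothing for the missing k positions. A standalone sketch of that trick (helper name is illustrative):

#include <stdint.h>
#include <string.h>

// Pack the last (k % 4) bytes of an A row into a zero-padded 32-bit word,
// ready to be broadcast with _mm512_set1_epi32. The zero padding makes the
// subsequent dpbusd add nothing for the absent k positions.
static uint32_t pack_k_fringe( const uint8_t* a_row_tail, int k_partial )
{
    uint32_t buf = 0;                                  // zero padding
    memcpy( &buf, a_row_tail, ( size_t )k_partial );   // k_partial is 1..3
    return buf;
}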
+ // c[4,0-47] = a[4,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); + c_int32_3p2 = _mm512_mullo_epi32( selector1, c_int32_3p2 ); + + c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); + c_int32_4p1 = _mm512_mullo_epi32( selector1, c_int32_4p1 ); + c_int32_4p2 = _mm512_mullo_epi32( selector1, c_int32_4p2 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p1 = 
_mm512_add_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p2 = _mm512_add_epi32( selector1, c_int32_3p2 ); + + // c[4,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p1 = _mm512_add_epi32( selector1, c_int32_4p1 ); + + // c[4,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p2 = _mm512_add_epi32( selector1, c_int32_4p2 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5x48: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3, 16-31] + c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_add_epi32( a_int32_0, c_int32_3p2 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4, 16-31] + c_int32_4p1 = _mm512_add_epi32( selector2, c_int32_4p1 ); + + // c[4,32-47] + c_int32_4p2 = _mm512_add_epi32( a_int32_0, c_int32_4p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x48: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_max_epi32( selector1, c_int32_3p2 ); + + // c[4,0-15] + c_int32_4p0 
= _mm512_max_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + c_int32_4p1 = _mm512_max_epi32( selector1, c_int32_4p1 ); + + // c[4,32-47] + c_int32_4p2 = _mm512_max_epi32( selector1, c_int32_4p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x48: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_2p2) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_3p2) + + // c[4, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_4p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_5x48: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[1, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[2, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[3, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); + + // c[3, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_3p2,a_int32_0,3,2); + + // c[4, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); + + // c[4, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_4p1,selector2,4,1); + + // c[4, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_4p2,a_int32_0,4,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x48_DISABLE: + ; + + // Store the results. 
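Annotation: the post-op sections above are alternatives selected from post_ops_list, not a fixed pipeline: a per-column int32 bias add, a plain ReLU (`_mm512_max_epi32` against zero), a parametric ReLU via RELU_SCALE_OP_S32_AVX512 (presumably scaling the negative side by op_args2), and a downscale to 8-bit via CVT_MULRND_CVT32_CVT8 (presumably multiply by a float scale factor, round, saturate). Both macros are defined elsewhere in the patch, so the sketch below is a rough scalar model of one accumulator element under those assumed semantics:

#include <stdint.h>
#include <math.h>

// Rough scalar model of a bias + ReLU + downscale chain on one accumulator.
// Only the post-ops actually present in post_ops_list run in the kernel;
// the round-and-saturate behaviour of CVT_MULRND_CVT32_CVT8 is assumed.
static int8_t post_ops_elem( int32_t acc, int32_t bias, float downscale )
{
    acc += bias;                                    // BIAS: per-column add
    acc = ( acc > 0 ) ? acc : 0;                    // RELU: clamp at zero
    float scaled = roundf( ( float )acc * downscale ); // DOWNSCALE (assumed)
    if ( scaled >  127.0f ) scaled =  127.0f;       // saturate to int8
    if ( scaled < -128.0f ) scaled = -128.0f;
    return ( int8_t )scaled;
}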
+ // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[1,32-47] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); + + // c[2,32-47] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 2*16 ), c_int32_2p2 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); + + // c[3,16-31] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 1*16 ), c_int32_3p1 ); + + // c[3,32-47] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 2*16 ), c_int32_3p2 ); + + // c[4,0-15] + _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 0*16 ), c_int32_4p0 ); + + // c[4,16-31] + _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 1*16 ), c_int32_4p1 ); + + // c[4,32-47] + _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 2*16 ), c_int32_4p2 ); +} + +// 4x48 int8o32 kernel +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x48) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_4x48_DISABLE, + &&POST_OPS_BIAS_4x48, + &&POST_OPS_RELU_4x48, + &&POST_OPS_RELU_SCALE_4x48, + &&POST_OPS_DOWNSCALE_4x48 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + + // A matrix storage. + __m512i a_int32_0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + __m512i c_int32_1p2 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + __m512i c_int32_2p2 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + __m512i c_int32_3p1 = _mm512_setzero_epi32(); + __m512i c_int32_3p2 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); + + // Broadcast a[2,kr:kr+4]. 
+ a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-47] = a[3,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + + // Broadcast a[3,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. 
+ // c[3,0-47] = a[3,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_0, b2 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); + c_int32_3p2 = _mm512_mullo_epi32( selector1, c_int32_3p2 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + 
c_int32_3p2 = _mm512_add_epi32( selector1, c_int32_3p2 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x48: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3, 16-31] + c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_add_epi32( a_int32_0, c_int32_3p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x48: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_max_epi32( selector1, c_int32_3p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x48: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_2p2) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_3p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x48: + { + selector1 = + 
_mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[1, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[2, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[3, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); + + // c[3, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_3p2,a_int32_0,3,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x48_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[1,32-47] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); + + // c[2,32-47] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 2*16 ), c_int32_2p2 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); + + // c[3,16-31] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 1*16 ), c_int32_3p1 ); + + // c[3,32-47] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 2*16 ), c_int32_3p2 ); +} + +// 3x48 int8o32 kernel +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x48) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_3x48_DISABLE, + &&POST_OPS_BIAS_3x48, + &&POST_OPS_RELU_3x48, + &&POST_OPS_RELU_SCALE_3x48, + &&POST_OPS_DOWNSCALE_3x48 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + + // A matrix storage. + __m512i a_int32_0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + __m512i c_int32_1p2 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + __m512i c_int32_2p2 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. 
+ a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. 
+ // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3x48: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[1,0-15] + c_int32_1p0 = 
_mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x48: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x48: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_2p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_3x48: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[1, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[2, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x48_DISABLE: + ; + + // Store the results. 
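Annotation: the POST_OPS_* labels and the `static void* post_ops_labels[]` table rely on the GCC/Clang labels-as-values extension; POST_OP_LABEL_LASTK_SAFE_JUMP and its _WITH_NEXT_PTR variant (defined in the lpgemm headers, not in this hunk) presumably resolve the next enabled post-op from the list and `goto *` its label, guarding against running post-ops before the last k panel. A minimal standalone illustration of that dispatch pattern, not the actual macros:

#include <stdio.h>

// Minimal labels-as-values dispatcher (GCC/Clang extension), mirroring the
// POST_OPS_* label tables: each stage jumps straight to the next enabled
// stage instead of walking an if/else chain.
int main( void )
{
    static void* labels[] = { &&op_done, &&op_bias, &&op_relu };
    int sequence[] = { 1, 2, 0 };   // bias -> relu -> done
    int step = 0;

    goto *labels[ sequence[ step++ ] ];
op_bias:
    printf( "bias\n" );
    goto *labels[ sequence[ step++ ] ];
op_relu:
    printf( "relu\n" );
    goto *labels[ sequence[ step++ ] ];
op_done:
    printf( "done\n" );
    return 0;
}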
+ // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[1,32-47] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); + + // c[2,32-47] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 2*16 ), c_int32_2p2 ); +} + +// 2x48 int8o32 kernel +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x48) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_2x48_DISABLE, + &&POST_OPS_BIAS_2x48, + &&POST_OPS_RELU_2x48, + &&POST_OPS_RELU_SCALE_2x48, + &&POST_OPS_DOWNSCALE_2x48 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + + // A matrix storage. + __m512i a_int32_0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + __m512i c_int32_1p2 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + + // Broadcast a[1,kr:kr+4]. 
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x48: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x48: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[1,0-15] + c_int32_1p0 = 
_mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x48: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2x48: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[1, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x48_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[1,32-47] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); +} + +// 1x48 int8o32 kernel +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x48) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_1x48_DISABLE, + &&POST_OPS_BIAS_1x48, + &&POST_OPS_RELU_1x48, + &&POST_OPS_RELU_SCALE_1x48, + &&POST_OPS_DOWNSCALE_1x48 + }; + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + + // A matrix storage. + __m512i a_int32_0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. 
+ // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x48: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x48: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x48: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1x48: + { + 
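+ // Downscale post-op: per-column float scale factors are loaded below as
+ // raw 32-bit lanes; CVT_MULRND_CVT32_CVT8 (see lpgemm_s32_kern_macros.h,
+ // not shown in this patch) is expected to scale, round and narrow each
+ // int32 accumulator, with the trailing (row, 16-column-chunk) arguments
+ // selecting where the 8-bit result is written.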
selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x48_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); +} +#endif diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c new file mode 100644 index 0000000000..856dc1355e --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c @@ -0,0 +1,2300 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include <immintrin.h> +#include <string.h> + +#include "blis.h" +#include "lpgemm_kernels.h" +#include "lpgemm_s32_kern_macros.h" + +#ifdef BLIS_KERNELS_ZEN4 +// 6xlt16 int8o32 fringe kernel +LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6xLT16_DISABLE, + &&POST_OPS_BIAS_6xLT16, + &&POST_OPS_RELU_6xLT16, + &&POST_OPS_RELU_SCALE_6xLT16, + &&POST_OPS_DOWNSCALE_6xLT16 + }; + dim_t MR = 6; + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage.
+ __m512i b0; + + // A matrix storage. + __m512i a_int32_0; + + // For corner cases. + int32_t buf0[16]; + int32_t buf1[16]; + int32_t buf2[16]; + int32_t buf3[16]; + int32_t buf4[16]; + int32_t buf5[16]; + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + + __m512i c_int32_4p0 = _mm512_setzero_epi32(); + + __m512i c_int32_5p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // Load 4 rows with 16 extended elements each from B to 1 ZMM + // registers. It is to be noted that the B matrix is packed for use + // in vnni instructions and each load to ZMM register will have 4 + // elements along k direction and 16 elements across n directions, + // so 4x16 elements to a ZMM register. + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + + // Broadcast a[4,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + + // Broadcast a[5,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[5,0-15] = a[5,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. 
+ // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + + // Broadcast a[4,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + + // Broadcast a[5,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[5,0-15] = a[5,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + + c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); + + c_int32_5p0 = _mm512_mullo_epi32( selector1, c_int32_5p0 ); + + // Scale C by beta. 
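+ // Together with the alpha scaling above, each accumulator chunk ends up
+ // holding c = beta*c + alpha*sum_k( a*b ), computed entirely in int32
+ // via _mm512_mullo_epi32/_mm512_add_epi32. When beta == 0 the existing
+ // C values are never read.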
+ if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * ( ir + 0 ) ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf1, ( c + ( rs_c * ( ir + 1 ) ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf2, ( c + ( rs_c * ( ir + 2 ) ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf3, ( c + ( rs_c * ( ir + 3 ) ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf4, ( c + ( rs_c * ( ir + 4 ) ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf5, ( c + ( rs_c * ( ir + 5 ) ) ), ( n0_rem * sizeof( int32_t ) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_epi32( buf0 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( buf1 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( buf2 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( buf3 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + selector1 = _mm512_loadu_epi32( buf4 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[5,0-15] + selector1 = _mm512_loadu_epi32( buf5 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6xLT16: + { + memcpy( buf0, ( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( int32_t ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[5,0-15] + c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6xLT16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + + // c[5,0-15] + c_int32_5p0 = _mm512_max_epi32( selector1, c_int32_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6xLT16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[4, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + + // c[5, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_5p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } 
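+ // Control flow note: post_ops_labels[] holds computed-goto addresses
+ // (GCC &&label extension) for the blocks in this chain; the
+ // POST_OP_LABEL_LASTK_SAFE_JUMP* macros (defined in
+ // lpgemm_s32_kern_macros.h, not shown here) presumably walk
+ // post_ops_list_temp and jump to the label matching each requested
+ // post-op, ending at POST_OPS_6xLT16_DISABLE.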
+POST_OPS_DOWNSCALE_6xLT16: + { + memcpy( buf0, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_0p0,selector1,0,0); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_1p0,selector1,1,0); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_2p0,selector1,2,0); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_3p0,selector1,3,0); + + // c[4, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_4p0,selector1,4,0); + + // c[5, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_5p0,selector1,5,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6xLT16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( buf0, c_int32_0p0 ); + + // c[1,0-15] + _mm512_storeu_epi32( buf1, c_int32_1p0 ); + + // c[2,0-15] + _mm512_storeu_epi32( buf2, c_int32_2p0 ); + + // c[3,0-15] + _mm512_storeu_epi32( buf3, c_int32_3p0 ); + + // c[4,0-15] + _mm512_storeu_epi32( buf4, c_int32_4p0 ); + + // c[5,0-15] + _mm512_storeu_epi32( buf5, c_int32_5p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), buf0, ( n0_rem * sizeof( int32_t ) ) ); + + // c[1,0-15] + memcpy( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), buf1, ( n0_rem * sizeof( int32_t ) ) ); + + // c[2,0-15] + memcpy( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), buf2, ( n0_rem * sizeof( int32_t ) ) ); + + // c[3,0-15] + memcpy( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), buf3, ( n0_rem * sizeof( int32_t ) ) ); + + // c[4,0-15] + memcpy( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), buf4, ( n0_rem * sizeof( int32_t ) ) ); + + // c[5,0-15] + memcpy( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), buf5, ( n0_rem * sizeof( int32_t ) ) ); + + a = a + ( MR * ps_a ); + post_op_c_i += MR; + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_u8s8s32o32_5xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 4 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_u8s8s32o32_4xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 3 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_u8s8s32o32_3xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 2 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_u8s8s32o32_2xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 1 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 
4 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_u8s8s32o32_1xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + } +} + +// 6x16 int8o32 fringe kernel +LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6x16_DISABLE, + &&POST_OPS_BIAS_6x16, + &&POST_OPS_RELU_6x16, + &&POST_OPS_RELU_SCALE_6x16, + &&POST_OPS_DOWNSCALE_6x16 + }; + dim_t MR = 6; + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage. + __m512i b0; + + // A matrix storage. + __m512i a_int32_0; + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + + __m512i c_int32_4p0 = _mm512_setzero_epi32(); + + __m512i c_int32_5p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // Load 4 rows with 16 elements each from B to 1 ZMM registers. It + // is to be noted that the B matrix is packed for use in vnni + // instructions and each load to ZMM register will have 4 elements + // along k direction and 16 elements across n directions, so 4x16 + // elements to a ZMM register. + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + + // Broadcast a[4,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + + // Broadcast a[5,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[5,0-15] = a[5,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); + } + // Handle k remainder. 
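+ // vpdpbusd consumes 4 bytes of A per 32-bit lane, so the last k0 % 4
+ // bytes of each A row are memcpy'd into the zero-initialized 4-byte
+ // a_kfringe_buf; the untouched high bytes remain zero and contribute
+ // nothing to the dot-product accumulation.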
+ if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + + // Broadcast a[4,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + + // Broadcast a[5,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[5,0-15] = a[5,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + + c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); + + c_int32_5p0 = _mm512_mullo_epi32( selector1, c_int32_5p0 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[5,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x16: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[5,0-15] + c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + + // c[5,0-15] + c_int32_5p0 = _mm512_max_epi32( selector1, c_int32_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[4, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + + // c[5, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_5p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_6x16: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[3, 0-15] + 
CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[4, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); + + // c[5, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_5p0,selector1,5,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x16_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_int32_0p0 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_int32_1p0 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_int32_2p0 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_int32_3p0 ); + + // c[4,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_int32_4p0 ); + + // c[5,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_int32_5p0 ); + + a = a + ( MR * ps_a ); + post_op_c_i += MR; + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_u8s8s32o32_5x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 4 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_u8s8s32o32_4x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 3 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_u8s8s32o32_3x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 2 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_u8s8s32o32_2x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 1 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_u8s8s32o32_1x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + } +} + +// 6x32 int8o32 fringe kernel +LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6x32_DISABLE, + &&POST_OPS_BIAS_6x32, + &&POST_OPS_RELU_6x32, + &&POST_OPS_RELU_SCALE_6x32, + &&POST_OPS_DOWNSCALE_6x32 + }; + dim_t MR = 6; + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage. + __m512i b0; + __m512i b1; + + // A matrix storage. + __m512i a_int32_0; + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. 
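+ // Naming convention: c_int32_<r>p<c> accumulates C row <r> of the
+ // current 6-row block, 16-column chunk <c>; the 6x32 tile therefore
+ // needs 12 ZMM accumulators (2 chunks per row).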
+ __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + __m512i c_int32_3p1 = _mm512_setzero_epi32(); + + __m512i c_int32_4p0 = _mm512_setzero_epi32(); + __m512i c_int32_4p1 = _mm512_setzero_epi32(); + + __m512i c_int32_5p0 = _mm512_setzero_epi32(); + __m512i c_int32_5p1 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // Load 4 rows with 32 elements each from B to 2 ZMM registers. It + // is to be noted that the B matrix is packed for use in vnni + // instructions and each load to ZMM register will have 4 elements + // along k direction and 16 elements across n directions, so 4x16 + // elements to a ZMM register. + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-31] = a[3,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + + // Broadcast a[4,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-31] = a[4,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + + // Broadcast a[5,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[5,0-31] = a[5,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); + c_int32_5p1 = _mm512_dpbusd_epi32( c_int32_5p1, a_int32_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. 
+ memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + + // Broadcast a[1,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + + // Broadcast a[2,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + + // Broadcast a[3,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-31] = a[3,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + + // Broadcast a[4,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-31] = a[4,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + + // Broadcast a[5,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. 
+ // c[5,0-31] = a[5,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); + c_int32_5p1 = _mm512_dpbusd_epi32( c_int32_5p1, a_int32_0, b1 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); + + c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); + c_int32_4p1 = _mm512_mullo_epi32( selector1, c_int32_4p1 ); + + c_int32_5p0 = _mm512_mullo_epi32( selector1, c_int32_5p0 ); + c_int32_5p1 = _mm512_mullo_epi32( selector1, c_int32_5p1 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); + + // c[4,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p1 = _mm512_add_epi32( selector1, c_int32_4p1 ); + + // c[5,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); + + // c[5,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ) ); + selector1 = 
_mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p1 = _mm512_add_epi32( selector1, c_int32_5p1 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x32: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3, 16-31] + c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4, 16-31] + c_int32_4p1 = _mm512_add_epi32( selector2, c_int32_4p1 ); + + // c[5,0-15] + c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); + + // c[5, 16-31] + c_int32_5p1 = _mm512_add_epi32( selector2, c_int32_5p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x32: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + c_int32_4p1 = _mm512_max_epi32( selector1, c_int32_4p1 ); + + // c[5,0-15] + c_int32_5p0 = _mm512_max_epi32( selector1, c_int32_5p0 ); + + // c[5,16-31] + c_int32_5p1 = _mm512_max_epi32( selector1, c_int32_5p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x32: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_3p1) + + // c[4, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_4p1) + + // c[5, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_5p0) + + // c[5, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_5p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_6x32: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 
* 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[3, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); + + // c[4, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); + + // c[4, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_4p1,selector2,4,1); + + // c[5, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_5p0,selector1,5,0); + + // c[5, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_5p1,selector2,5,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x32_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ), c_int32_0p1 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ), c_int32_1p1 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ), c_int32_2p1 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_int32_3p0 ); + + // c[3,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ), c_int32_3p1 ); + + // c[4,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_int32_4p0 ); + + // c[4,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ), c_int32_4p1 ); + + // c[5,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_int32_5p0 ); + + // c[5,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_int32_5p1 ); + + a = a + ( MR * ps_a ); + post_op_c_i += MR; + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_u8s8s32o32_5x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 4 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_u8s8s32o32_4x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 3 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_u8s8s32o32_3x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 2 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 
4 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_u8s8s32o32_2x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 1 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_u8s8s32o32_1x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + } +} + +// 6x48 int8o32 fringe kernel +LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6x48_DISABLE, + &&POST_OPS_BIAS_6x48, + &&POST_OPS_RELU_6x48, + &&POST_OPS_RELU_SCALE_6x48, + &&POST_OPS_DOWNSCALE_6x48 + }; + dim_t MR = 6; + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + + // A matrix storage. + __m512i a_int32_0; + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + __m512i c_int32_1p2 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + __m512i c_int32_2p2 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + __m512i c_int32_3p1 = _mm512_setzero_epi32(); + __m512i c_int32_3p2 = _mm512_setzero_epi32(); + + __m512i c_int32_4p0 = _mm512_setzero_epi32(); + __m512i c_int32_4p1 = _mm512_setzero_epi32(); + __m512i c_int32_4p2 = _mm512_setzero_epi32(); + + __m512i c_int32_5p0 = _mm512_setzero_epi32(); + __m512i c_int32_5p1 = _mm512_setzero_epi32(); + __m512i c_int32_5p2 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // Load 4 rows with 48 elements each from B to 3 ZMM registers. It + // is to be noted that the B matrix is packed for use in vnni + // instructions and each load to ZMM register will have 4 elements + // along k direction and 16 elements across n directions, so 4x16 + // elements to a ZMM register. + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. 
+ // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-47] = a[3,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_0, b2 ); + + // Broadcast a[4,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-47] = a[4,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); + + // Broadcast a[5,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[5,0-47] = a[5,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); + c_int32_5p1 = _mm512_dpbusd_epi32( c_int32_5p1, a_int32_0, b1 ); + c_int32_5p2 = _mm512_dpbusd_epi32( c_int32_5p2, a_int32_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + + // Broadcast a[1,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); + + // Broadcast a[2,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. 
+ // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + + // Broadcast a[3,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-47] = a[3,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_0, b2 ); + + // Broadcast a[4,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-47] = a[4,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); + + // Broadcast a[5,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[5,0-47] = a[5,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); + c_int32_5p1 = _mm512_dpbusd_epi32( c_int32_5p1, a_int32_0, b1 ); + c_int32_5p2 = _mm512_dpbusd_epi32( c_int32_5p2, a_int32_0, b2 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); + c_int32_3p2 = _mm512_mullo_epi32( selector1, c_int32_3p2 ); + + c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); + c_int32_4p1 = _mm512_mullo_epi32( selector1, c_int32_4p1 ); + c_int32_4p2 = _mm512_mullo_epi32( selector1, c_int32_4p2 ); + + c_int32_5p0 = _mm512_mullo_epi32( selector1, c_int32_5p0 ); + c_int32_5p1 = _mm512_mullo_epi32( selector1, c_int32_5p1 ); + c_int32_5p2 = _mm512_mullo_epi32( selector1, c_int32_5p2 ); + + // Scale C by beta. 
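Since this is an integer kernel, alpha and beta are applied with 32-bit integer multiplies (`_mm512_mullo_epi32`) rather than floating-point scaling. Per output element, the alpha pass above and the beta pass below amount to the following (illustrative scalar form):

// s32 GEMM update applied here: c_new = alpha * acc + beta * c_old,
// with the beta term skipped entirely when beta == 0.
static inline int32_t scale_c_s32_ref
     (
       int32_t acc,
       int32_t c_old,
       int32_t alpha,
       int32_t beta
     )
{
    int32_t out = alpha * acc;
    if ( beta != 0 )
    {
        out += beta * c_old;
    }
    return out;
}

The vectorized beta pass for all 18 accumulators follows.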
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p2 = _mm512_add_epi32( selector1, c_int32_3p2 ); + + // c[4,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p1 = _mm512_add_epi32( selector1, c_int32_4p1 ); + + // c[4,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p2 = _mm512_add_epi32( selector1, c_int32_4p2 ); + + // c[5,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); + + // c[5,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p1 = 
_mm512_add_epi32( selector1, c_int32_5p1 ); + + // c[5,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p2 = _mm512_add_epi32( selector1, c_int32_5p2 ); + } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x48: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3, 16-31] + c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_add_epi32( a_int32_0, c_int32_3p2 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4, 16-31] + c_int32_4p1 = _mm512_add_epi32( selector2, c_int32_4p1 ); + + // c[4,32-47] + c_int32_4p2 = _mm512_add_epi32( a_int32_0, c_int32_4p2 ); + + // c[5,0-15] + c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); + + // c[5, 16-31] + c_int32_5p1 = _mm512_add_epi32( selector2, c_int32_5p1 ); + + // c[5,32-47] + c_int32_5p2 = _mm512_add_epi32( a_int32_0, c_int32_5p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x48: + { + //printf("relu\n"); + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_max_epi32( selector1, c_int32_3p2 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + c_int32_4p1 = _mm512_max_epi32( selector1, c_int32_4p1 ); + + // c[4,32-47] + c_int32_4p2 = _mm512_max_epi32( selector1, c_int32_4p2 ); + + // c[5,0-15] + c_int32_5p0 = _mm512_max_epi32( selector1, c_int32_5p0 ); + + // c[5,16-31] + c_int32_5p1 = _mm512_max_epi32( 
selector1, c_int32_5p1 ); + + // c[5,32-47] + c_int32_5p2 = _mm512_max_epi32( selector1, c_int32_5p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x48: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_2p2) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_3p2) + + // c[4, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_4p2) + + // c[5, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_5p0) + + // c[5, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_5p1) + + // c[5, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_5p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_6x48: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[1, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[2, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[3, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); + + // c[3, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_3p2,a_int32_0,3,2); + + // c[4, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); + + // c[4, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_4p1,selector2,4,1); + + // c[4, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_4p2,a_int32_0,4,2); + + // c[5, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_5p0,selector1,5,0); + + // c[5, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_5p1,selector2,5,1); + + // c[5, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_5p2,a_int32_0,5,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x48_DISABLE: + ; + + // Store the results. 
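The post-op sections above are chained with computed gotos: `post_ops_labels[]` holds the `&&label` addresses in the order of the post-op kinds, and the `POST_OP_LABEL_LASTK_SAFE_JUMP` / `..._WITH_NEXT_PTR` macros (defined elsewhere in the addon, not in this hunk) jump to the label selected by the current `lpgemm_post_op` node and then advance to the next node, falling back to the DISABLE label when the list is exhausted. A minimal sketch of the dispatch idea; the `op_code` field name is an assumption, and the real macros also appear to guard on `is_last_k` (hence the name) so post-ops run only once per C tile:

// Illustrative only: jump to the handler of the current post-op node,
// or to the DISABLE label (index 0) when there is nothing left to do.
#define DEMO_POST_OP_JUMP( list, labels ) \
    if ( ( list ) != NULL ) { goto *( labels )[ ( list )->op_code ]; } \
    else                    { goto *( labels )[ 0 ]; }

Once the chain reaches `POST_OPS_6x48_DISABLE`, the 18 accumulators are written back to C below.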
+ // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ), c_int32_0p2 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ), c_int32_1p1 ); + + // c[1,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ), c_int32_1p2 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ), c_int32_2p1 ); + + // c[2,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ), c_int32_2p2 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_int32_3p0 ); + + // c[3,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ), c_int32_3p1 ); + + // c[3,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ), c_int32_3p2 ); + + // c[4,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_int32_4p0 ); + + // c[4,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ), c_int32_4p1 ); + + // c[4,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ), c_int32_4p2 ); + + // c[5,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_int32_5p0 ); + + // c[5,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_int32_5p1 ); + + // c[5,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ), c_int32_5p2 ); + + a = a + ( MR * ps_a ); + post_op_c_i += MR; + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_u8s8s32o32_5x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 4 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_u8s8s32o32_4x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 3 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_u8s8s32o32_3x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 2 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_u8s8s32o32_2x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + else if ( m_partial_pieces == 1 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 
4 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_u8s8s32o32_1x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list, rs_c_downscale + ); + } + } +} +#endif diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa.h new file mode 100644 index 0000000000..b983b0c617 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa.h @@ -0,0 +1,55 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_GEMM_INT8_PACKA +#define BLIS_GEMM_INT8_PACKA + +void get_packa_k64_u8s8s32o32_strides + ( + dim_t* rs_a, + dim_t* cs_a + ); + +void packa_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t MC, + const dim_t KC, + dim_t* rs_a, + dim_t* cs_a + ); + +#endif //BLIS_GEMM_INT8_PACKA diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa_amd512vnni.c new file mode 100644 index 0000000000..601b8a3eff --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa_amd512vnni.c @@ -0,0 +1,520 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include + +#include "blis.h" +#include "lpgemm_packa.h" + +#define MR 6 +#define NR 64 + +void get_packa_k64_u8s8s32o32_strides + ( + dim_t* rs_a, + dim_t* cs_a + ) +{ + *rs_a = 4; + *cs_a = 24; +} + +#ifdef BLIS_KERNELS_ZEN4 +void packa_m5_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t KC + ); + +void packa_m4_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t KC + ); + +void packa_m3_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t KC + ); + +void packa_m2_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t KC + ); + +void packa_m1_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t KC + ); + +// TODO: k fringe till k=4, k%4=0 and padding to make k'%4 = 0 if k%4 != 0 originally. +void packa_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t MC, + const dim_t KC, + dim_t* rs_a, + dim_t* cs_a + ) +{ + // Used for permuting the mm512i elements for use in vpdpbusd instruction. + // These are indexes of the format a0-a1-b0-b1-a2-a3-b2-b3 and a0-a1-a2-a3-b0-b1-b2-b3. + // Adding 4 int32 wise gives format a4-a5-b4-b5-a6-a7-b6-b7 and a4-a5-a6-a7-b4-b5-b6-b7. + __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB ); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + __m512i selector2 = _mm512_setr_epi64( 0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xA, 0xB ); + __m512i selector2_1 = _mm512_setr_epi64( 0x4, 0x5, 0x6, 0x7, 0xC, 0xD, 0xE, 0xF ); + + // First half. + __m512i selector3 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x2, 0x3, 0x9, 0x4, 0x5 ); // 64 elems + __m512i selector4 = _mm512_setr_epi64( 0x8, 0x6, 0x7, 0x9, 0x0, 0x0, 0x0, 0x0 ); // 32 elems + __m512i selector5 = _mm512_setr_epi64( 0x0, 0x1, 0xA, 0x2, 0x3, 0xB, 0x4, 0x5 ); // 64 elems + __m512i selector6 = _mm512_setr_epi64( 0xA, 0x6, 0x7, 0xB, 0x0, 0x0, 0x0, 0x0 ); // 32 elems + + // Second half. 
+ __m512i selector7 = _mm512_setr_epi64( 0x0, 0x1, 0xC, 0x2, 0x3, 0xD, 0x4, 0x5 ); // 64 elems + __m512i selector8 = _mm512_setr_epi64( 0xC, 0x6, 0x7, 0xD, 0x0, 0x0, 0x0, 0x0 ); // 32 elems + __m512i selector9 = _mm512_setr_epi64( 0x0, 0x1, 0xE, 0x2, 0x3, 0xF, 0x4, 0x5 ); // 64 elems + __m512i selector10 = _mm512_setr_epi64( 0xE, 0x6, 0x7, 0xF, 0x0, 0x0, 0x0, 0x0 ); // 32 elems + + dim_t m_full_pieces = MC / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = MC % MR; + + __m512i a0; + __m512i b0; + __m512i c0; + __m512i d0; + __m512i e0; + __m512i f0; + __m512i a01; + __m512i c01; + __m512i e01; + __m256i last_piece; + + for ( dim_t ic = 0; ic < m_full_pieces_loop_limit; ic += MR ) + { + for ( dim_t kr = 0; kr < KC; kr += NR ) + { + // Rearrange for vpdpbusd, read 6 rows from A with 64 elements in each row. + a0 = _mm512_loadu_epi8( a + ( lda * ( ic + 0 ) ) + kr ); + b0 = _mm512_loadu_epi8( a + ( lda * ( ic + 1 ) ) + kr ); + c0 = _mm512_loadu_epi8( a + ( lda * ( ic + 2 ) ) + kr ); + d0 = _mm512_loadu_epi8( a + ( lda * ( ic + 3 ) ) + kr ); + e0 = _mm512_loadu_epi8( a + ( lda * ( ic + 4 ) ) + kr ); + f0 = _mm512_loadu_epi8( a + ( lda * ( ic + 5 ) ) + kr ); + + a01 = _mm512_unpacklo_epi32( a0, b0 ); + a0 = _mm512_unpackhi_epi32( a0, b0 ); + + c01 = _mm512_unpacklo_epi32( c0, d0 ); + c0 = _mm512_unpackhi_epi32( c0, d0 ); + + e01 = _mm512_unpacklo_epi32( e0, f0 ); // Elem 4 + e0 = _mm512_unpackhi_epi32( e0, f0 ); // Elem 5 + + b0 = _mm512_unpacklo_epi64( a01, c01 ); + a01 = _mm512_unpackhi_epi64( a01, c01 ); + + d0 = _mm512_unpacklo_epi64( a0, c0 ); + c01 = _mm512_unpackhi_epi64( a0, c0 ); + + a0 = _mm512_permutex2var_epi64( b0, selector1, a01 ); + c0 = _mm512_permutex2var_epi64( d0, selector1, c01 ); + b0 = _mm512_permutex2var_epi64( b0, selector1_1, a01 ); + d0 = _mm512_permutex2var_epi64( d0, selector1_1, c01 ); + + a01 = _mm512_permutex2var_epi64( a0, selector2, c0 ); // a[0] + c01 = _mm512_permutex2var_epi64( b0, selector2, d0 ); // a[2] + a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // a[1] + c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // a[3] + + // First half + b0 = _mm512_permutex2var_epi64( a01, selector3, e01 ); // 1st 64 + a01 = _mm512_permutex2var_epi64( a01, selector4, e0 ); // 1st 32 + d0 = _mm512_permutex2var_epi64( a0, selector5, e01 ); // 2nd 64 + a0 = _mm512_permutex2var_epi64( a0, selector6, e0 ); // 2nd 32 + + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( ic * KC ) + ( ( kr * MR ) + ( 0 ) ) ), b0 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( ic * KC ) + ( ( kr * MR ) + ( 64 ) ) ) , a01 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( ic * KC ) + ( ( kr * MR ) + ( 96 ) ) ), d0 ); + // Last piece + last_piece = _mm512_castsi512_si256( a0 ); + _mm256_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( ic * KC ) + ( ( kr * MR ) + ( 160 ) ) ), last_piece ); + + // Second half + b0 = _mm512_permutex2var_epi64( c01, selector7, e01 ); // 3rd 64 + c01 = _mm512_permutex2var_epi64( c01, selector8, e0 ); // 3rd 32 + d0 = _mm512_permutex2var_epi64( c0, selector9, e01 ); // 4th 64 + c0 = _mm512_permutex2var_epi64( c0, selector10, e0 ); // 4th 32 + + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( ic * KC ) + ( ( kr * MR ) + ( 192 ) ) ), b0 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( ic * KC ) + ( ( kr * MR ) + ( 256 ) ) ) , c01 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( ic * KC ) + ( ( kr * MR ) + ( 288 ) ) ), d0 ); + // Last piece + last_piece = _mm512_castsi512_si256( c0 ); + 
_mm256_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( ic * KC ) + ( ( kr * MR ) + ( 352 ) ) ), last_piece ); + } + //TODO: Handle kc < 64 case, 48,32,16 + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + packa_m5_k64_u8s8s32o32 + ( + pack_a_buffer_u8s8s32o32 + ( m_full_pieces_loop_limit * KC ), + a + ( lda * m_full_pieces_loop_limit ), lda, KC + ); + } + else if ( m_partial_pieces == 4 ) + { + packa_m4_k64_u8s8s32o32 + ( + pack_a_buffer_u8s8s32o32 + ( m_full_pieces_loop_limit * KC ), + a + ( lda * m_full_pieces_loop_limit ), lda, KC + ); + } + else if ( m_partial_pieces == 3 ) + { + packa_m3_k64_u8s8s32o32 + ( + pack_a_buffer_u8s8s32o32 + ( m_full_pieces_loop_limit * KC ), + a + ( lda * m_full_pieces_loop_limit ), lda, KC + ); + } + else if ( m_partial_pieces == 2 ) + { + packa_m2_k64_u8s8s32o32 + ( + pack_a_buffer_u8s8s32o32 + ( m_full_pieces_loop_limit * KC ), + a + ( lda * m_full_pieces_loop_limit ), lda, KC + ); + } + else if ( m_partial_pieces == 1 ) + { + packa_m1_k64_u8s8s32o32 + ( + pack_a_buffer_u8s8s32o32 + ( m_full_pieces_loop_limit * KC ), + a + ( lda * m_full_pieces_loop_limit ), lda, KC + ); + } + } + *rs_a = 4; + *cs_a = 24; +} + +void packa_m5_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t KC + ) +{ + // Used for permuting the mm512i elements for use in vpdpbusd instruction. + // These are indexes of the format a0-a1-b0-b1-a2-a3-b2-b3 and a0-a1-a2-a3-b0-b1-b2-b3. + // Adding 4 int32 wise gives format a4-a5-b4-b5-a6-a7-b6-b7 and a4-a5-a6-a7-b4-b5-b6-b7. + __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB ); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + __m512i selector2 = _mm512_setr_epi64( 0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xA, 0xB ); + __m512i selector2_1 = _mm512_setr_epi64( 0x4, 0x5, 0x6, 0x7, 0xC, 0xD, 0xE, 0xF ); + + // First half. + __m512i selector3 = _mm512_setr_epi32( 0x0, 0x1, 0x2, 0x3, 0x10, 0x4, 0x5, 0x6, 0x7, 0x11, 0x8, 0x9, 0xA, 0xB, 0x12, 0xC); + __m512i selector4 = _mm512_setr_epi32( 0xD, 0xE, 0xF, 0x13, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0); + __m512i selector5 = _mm512_setr_epi32( 0x0, 0x1, 0x2, 0x3, 0x14, 0x4, 0x5, 0x6, 0x7, 0x15, 0x8, 0x9, 0xA, 0xB, 0x16, 0xC); + __m512i selector6 = _mm512_setr_epi32( 0xD, 0xE, 0xF, 0x17, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0); + + // Second half. + __m512i selector7 = _mm512_setr_epi32( 0x0, 0x1, 0x2, 0x3, 0x18, 0x4, 0x5, 0x6, 0x7, 0x19, 0x8, 0x9, 0xA, 0xB, 0x1A, 0xC); + __m512i selector8 = _mm512_setr_epi32( 0xD, 0xE, 0xF, 0x1B, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0); + __m512i selector9 = _mm512_setr_epi32( 0x0, 0x1, 0x2, 0x3, 0x1C, 0x4, 0x5, 0x6, 0x7, 0x1D, 0x8, 0x9, 0xA, 0xB, 0x1E, 0xC); + __m512i selector10 = _mm512_setr_epi32( 0xD, 0xE, 0xF, 0x1F, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0); + + __m512i a0; + __m512i b0; + __m512i c0; + __m512i d0; + __m512i e0; + __m512i a01; + __m512i c01; + __m128i last_piece; + + for ( dim_t kr = 0; kr < KC; kr += NR ) + { + // Rearrange for vpdpbusd, read 5 rows from A with 64 elements in each row. 
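`packa_k64_u8s8s32o32` reports `rs_a = 4` and `cs_a = 24`: in the packed buffer every k-group of 4 bytes from the six rows of a micro-panel is laid out back to back (6 x 4 = 24 bytes), so the compute kernels read row i, k-group g at `a + 4*i + 24*g`. The `cs_a_use = ( cs_a / 6 ) * m` adjustments in the m-fringe dispatch earlier in this patch apply the same rule with the fringe height m in place of 6, while the `cs_a == 4` branch appears to cover A being consumed unpacked, with its four k-bytes already contiguous. A sketch of the packed addressing (illustrative helper derived from these strides, not part of the patch):

// Byte offset of A(i, k) inside a VNNI-packed A panel of height mr
// (mr = 6 for a full micro-panel, or the fringe height 1..5).
static inline dim_t packed_a_offset_ref
     (
       dim_t i,
       dim_t k,
       dim_t mr
     )
{
    return ( ( k / 4 ) * ( 4 * mr ) ) + ( 4 * i ) + ( k % 4 );
}

The m = 5 variant below packs the same layout with mr = 5; the loads of its five source rows follow.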
+ a0 = _mm512_loadu_epi8( a + ( lda * 0 ) + kr ); + b0 = _mm512_loadu_epi8( a + ( lda * 1 ) + kr ); + c0 = _mm512_loadu_epi8( a + ( lda * 2 ) + kr ); + d0 = _mm512_loadu_epi8( a + ( lda * 3 ) + kr ); + e0 = _mm512_loadu_epi8( a + ( lda * 4 ) + kr ); + + a01 = _mm512_unpacklo_epi32( a0, b0 ); + a0 = _mm512_unpackhi_epi32( a0, b0 ); + + c01 = _mm512_unpacklo_epi32( c0, d0 ); + c0 = _mm512_unpackhi_epi32( c0, d0 ); + + b0 = _mm512_unpacklo_epi64( a01, c01 ); + a01 = _mm512_unpackhi_epi64( a01, c01 ); + + d0 = _mm512_unpacklo_epi64( a0, c0 ); + c01 = _mm512_unpackhi_epi64( a0, c0 ); + + a0 = _mm512_permutex2var_epi64( b0, selector1, a01 ); + c0 = _mm512_permutex2var_epi64( d0, selector1, c01 ); + b0 = _mm512_permutex2var_epi64( b0, selector1_1, a01 ); + d0 = _mm512_permutex2var_epi64( d0, selector1_1, c01 ); + + a01 = _mm512_permutex2var_epi64( a0, selector2, c0 ); // a[0] + c01 = _mm512_permutex2var_epi64( b0, selector2, d0 ); // a[2] + a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // a[1] + c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // a[3] + + // First half + b0 = _mm512_permutex2var_epi32( a01, selector3, e0 ); + a01 = _mm512_permutex2var_epi32( a01, selector4, e0 ); + d0 = _mm512_permutex2var_epi32( a0, selector5, e0 ); + a0 = _mm512_permutex2var_epi32( a0, selector6, e0 ); + + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 5 ) + ( 0 ) ), b0 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 5 ) + ( 64 ) ) , a01 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 5 ) + ( 80 ) ), d0 ); + // Last piece + last_piece = _mm512_castsi512_si128( a0 ); + _mm_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 5 ) + ( 144 ) ), last_piece ); + + // Second half + b0 = _mm512_permutex2var_epi32( c01, selector7, e0 ); + c01 = _mm512_permutex2var_epi32( c01, selector8, e0 ); + d0 = _mm512_permutex2var_epi32( c0, selector9, e0 ); + c0 = _mm512_permutex2var_epi32( c0, selector10, e0 ); + + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 5 ) + ( 160 ) ), b0 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 5 ) + ( 224 ) ) , c01 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 5 ) + ( 240 ) ), d0 ); + // Last piece + last_piece = _mm512_castsi512_si128( c0 ); + _mm_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 5 ) + ( 304 ) ), last_piece ); + } +} + +void packa_m4_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t KC + ) +{ + // Used for permuting the mm512i elements for use in vpdpbusd instruction. + // These are indexes of the format a0-a1-b0-b1-a2-a3-b2-b3 and a0-a1-a2-a3-b0-b1-b2-b3. + // Adding 4 int32 wise gives format a4-a5-b4-b5-a6-a7-b6-b7 and a4-a5-a6-a7-b4-b5-b6-b7. + __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB ); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + __m512i selector2 = _mm512_setr_epi64( 0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xA, 0xB ); + __m512i selector2_1 = _mm512_setr_epi64( 0x4, 0x5, 0x6, 0x7, 0xC, 0xD, 0xE, 0xF ); + + __m512i a0; + __m512i b0; + __m512i c0; + __m512i d0; + __m512i a01; + __m512i c01; + + for ( dim_t kr = 0; kr < KC; kr += NR ) + { + // Rearrange for vpdpbusd, read 4 rows from A with 64 elements in each row. 
+ a0 = _mm512_loadu_epi8( a + ( lda * 0 ) + kr ); + b0 = _mm512_loadu_epi8( a + ( lda * 1 ) + kr ); + c0 = _mm512_loadu_epi8( a + ( lda * 2 ) + kr ); + d0 = _mm512_loadu_epi8( a + ( lda * 3 ) + kr ); + + a01 = _mm512_unpacklo_epi32( a0, b0 ); + a0 = _mm512_unpackhi_epi32( a0, b0 ); + + c01 = _mm512_unpacklo_epi32( c0, d0 ); + c0 = _mm512_unpackhi_epi32( c0, d0 ); + + b0 = _mm512_unpacklo_epi64( a01, c01 ); + a01 = _mm512_unpackhi_epi64( a01, c01 ); + + d0 = _mm512_unpacklo_epi64( a0, c0 ); + c01 = _mm512_unpackhi_epi64( a0, c0 ); + + a0 = _mm512_permutex2var_epi64( b0, selector1, a01 ); + c0 = _mm512_permutex2var_epi64( d0, selector1, c01 ); + b0 = _mm512_permutex2var_epi64( b0, selector1_1, a01 ); + d0 = _mm512_permutex2var_epi64( d0, selector1_1, c01 ); + + a01 = _mm512_permutex2var_epi64( a0, selector2, c0 ); // a[0] + c01 = _mm512_permutex2var_epi64( b0, selector2, d0 ); // a[2] + a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // a[1] + c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // a[3] + + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 4 ) + ( 0 ) ), a01 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 4 ) + ( 64 ) ) , a0 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 4 ) + ( 128 ) ), c01 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 4 ) + ( 192 ) ), c0 ); + } +} + +void packa_m3_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t KC + ) +{ + // Used for permuting the mm512i elements for use in vpdpbusd instruction. + // These are indexes of the format a0-a1-b0-b1-a2-a3-b2-b3 and a0-a1-a2-a3-b0-b1-b2-b3. + // Adding 4 int32 wise gives format a4-a5-b4-b5-a6-a7-b6-b7 and a4-a5-a6-a7-b4-b5-b6-b7. + __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB ); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + + // First half + __m512i selector3 = _mm512_setr_epi32( 0x0, 0x1, 0x10, 0x2, 0x3, 0x11, 0x4, 0x5, 0x12, 0x6, 0x7, 0x13, 0x8, 0x9, 0x14, 0xA ); + __m512i selector4 = _mm512_setr_epi32( 0xB, 0x15, 0xC, 0xD, 0x16, 0xE, 0xF, 0x17, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0 ); + + // Second half + __m512i selector5 = _mm512_setr_epi32( 0x0, 0x1, 0x18, 0x2, 0x3, 0x19, 0x4, 0x5, 0x1A, 0x6, 0x7, 0x1B, 0x8, 0x9, 0x1C, 0xA ); + __m512i selector6 = _mm512_setr_epi32( 0xB, 0x1D, 0xC, 0xD, 0x1E, 0xE, 0xF, 0x1F, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0 ); + + __m512i a0; + __m512i b0; + __m512i c0; + __m512i a01; + __m256i last_piece; + + for ( dim_t kr = 0; kr < KC; kr += NR ) + { + // Rearrange for vpdpbusd, read 3 rows from A with 64 elements in each row. 
+ a0 = _mm512_loadu_epi8( a + ( lda * 0 ) + kr ); + b0 = _mm512_loadu_epi8( a + ( lda * 1 ) + kr ); + c0 = _mm512_loadu_epi8( a + ( lda * 2 ) + kr ); + + a01 = _mm512_unpacklo_epi32( a0, b0 ); + a0 = _mm512_unpackhi_epi32( a0, b0 ); + + b0 = _mm512_permutex2var_epi64( a01, selector1, a0 ); // a[0] + a01 = _mm512_permutex2var_epi64( a01, selector1_1, a0 ); // a[1] + + a0 = _mm512_permutex2var_epi32( b0, selector3, c0 ); + b0 = _mm512_permutex2var_epi32( b0, selector4, c0 ); + + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 3 ) + ( 0 ) ), a0 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 3 ) + ( 64 ) ) , b0 ); + + a0 = _mm512_permutex2var_epi32( a01, selector5, c0 ); + b0 = _mm512_permutex2var_epi32( a01, selector6, c0 ); + + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 3 ) + ( 96 ) ), a0 ); + // Last piece + last_piece = _mm512_castsi512_si256( b0 ); + _mm256_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 3 ) + ( 160 ) ), last_piece ); + } +} + +void packa_m2_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t KC + ) +{ + // Used for permuting the mm512i elements for use in vpdpbusd instruction. + // These are indexes of the format a0-a1-b0-b1-a2-a3-b2-b3 and a0-a1-a2-a3-b0-b1-b2-b3. + // Adding 4 int32 wise gives format a4-a5-b4-b5-a6-a7-b6-b7 and a4-a5-a6-a7-b4-b5-b6-b7. + __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB ); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + + __m512i a0; + __m512i b0; + __m512i a01; + + for ( dim_t kr = 0; kr < KC; kr += NR ) + { + // Rearrange for vpdpbusd, read 2 rows from A with 64 elements in each row. + a0 = _mm512_loadu_epi8( a + ( lda * 0 ) + kr ); + b0 = _mm512_loadu_epi8( a + ( lda * 1 ) + kr ); + + a01 = _mm512_unpacklo_epi32( a0, b0 ); + a0 = _mm512_unpackhi_epi32( a0, b0 ); + + b0 = _mm512_permutex2var_epi64( a01, selector1, a0 ); // a[0] + a01 = _mm512_permutex2var_epi64( a01, selector1_1, a0 ); // a[1] + + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 2 ) + ( 0 ) ), b0 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 2 ) + ( 64 ) ) , a01 ); + } +} + +void packa_m1_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t KC + ) +{ + __m512i a0; + + for ( dim_t kr = 0; kr < KC; kr += NR ) + { + // Rearrange for vpdpbusd, read 1 row from A with 64 elements in each row. + a0 = _mm512_loadu_epi8( a + ( lda * 0 ) + kr ); + + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 1 ) + ( 0 ) ), a0 ); + } +} +#endif diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb.h new file mode 100644 index 0000000000..3f310c0a48 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb.h @@ -0,0 +1,65 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_GEMM_INT8_PACKB +#define BLIS_GEMM_INT8_PACKB + +BLIS_INLINE dim_t get_packb_u8s8s32o32_min_NR() +{ + // This is the minimum NR' required for use in u8s8s32 kernels. The idea + // here is that since k needs to be a multiple of 4 (VNNI instr), NR'=16 + // results in total of 4 * NR' = 64 bytes to be loaded, which fits in 1 ZMM + // register. Thus the smallest n fringe kernel dimension has n=16, and thus + // any rounding for buffer sizes should be to 16. + return 16; +} + +void get_packb_nr64_u8s8s32o32_strides + ( + dim_t* rs_b, + dim_t* cs_b + ); + +void packb_nr64_u8s8s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t ldb, + const dim_t NC, + const dim_t KC, + dim_t* rs_b, + dim_t* cs_b + ); + +#endif //BLIS_GEMM_INT8_PACKB diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb_amd512vnni.c new file mode 100644 index 0000000000..d388c476e9 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb_amd512vnni.c @@ -0,0 +1,794 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include + +#include "blis.h" +#include "lpgemm_packb.h" + +#define NR 64 + +void get_packb_nr64_u8s8s32o32_strides + ( + dim_t* rs_b, + dim_t* cs_b + ) +{ + *rs_b = NR * 4; + *cs_b = NR; +} + +#ifdef BLIS_KERNELS_ZEN4 +void packb_nrlt16_u8s8s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t ldb, + const dim_t KC, + const dim_t n0_partial_rem + ); + +void packb_nr16_u8s8s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t ldb, + const dim_t KC + ); + +void packb_nr32_u8s8s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t ldb, + const dim_t KC + ); + +void packb_nr48_u8s8s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t ldb, + const dim_t KC + ); + +void packb_nr64_u8s8s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t ldb, + const dim_t NC, + const dim_t KC, + dim_t* rs_b, + dim_t* cs_b + ) +{ + // Used for permuting the mm512i elements for use in vpdpbusd instruction. + // These are indexes of the format a0-a1-b0-b1-a2-a3-b2-b3 and a0-a1-a2-a3-b0-b1-b2-b3. + // Adding int32 wise all4 gives format a4-a5-b4-b5-a6-a7-b6-b7 and a4-a5-a6-a7-b4-b5-b6-b7. + __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB ); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + + __m512i selector2 = _mm512_setr_epi64( 0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xA, 0xB ); + __m512i selector2_1 = _mm512_setr_epi64( 0x4, 0x5, 0x6, 0x7, 0xC, 0xD, 0xE, 0xF ); + + dim_t n_full_pieces = NC / NR; + dim_t n_full_pieces_loop_limit = n_full_pieces * NR; + dim_t n_partial_pieces = NC % NR; + + dim_t k_full_pieces_blks = KC / 4; + dim_t k_full_pieces = k_full_pieces_blks * 4; + dim_t k_partial_pieces = KC % 4; + + // KC when not multiple of 4 will have padding to make it multiple of 4 in packed buffer. + dim_t KC_updated = KC; + if ( k_partial_pieces > 0 ) + { + KC_updated += ( 4 - k_partial_pieces ); + } + + __m512i a0; + __m512i b0; + __m512i c0; + __m512i d0; + __m512i a01; + __m512i c01; + + for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR ) + { + for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. 
+ a0 = _mm512_loadu_epi8( b + ( ldb * ( kr + 0 ) ) + jc ); + b0 = _mm512_loadu_epi8( b + ( ldb * ( kr + 1 ) ) + jc ); + c0 = _mm512_loadu_epi8( b + ( ldb * ( kr + 2 ) ) + jc ); + d0 = _mm512_loadu_epi8( b + ( ldb * ( kr + 3 ) ) + jc ); + + a01 = _mm512_unpacklo_epi8( a0, b0 ); + a0 = _mm512_unpackhi_epi8( a0, b0 ); + + c01 = _mm512_unpacklo_epi8( c0, d0 ); + c0 = _mm512_unpackhi_epi8( c0, d0 ); + + b0 = _mm512_unpacklo_epi16( a01, c01 ); + a01 = _mm512_unpackhi_epi16( a01, c01 ); + + d0 = _mm512_unpacklo_epi16( a0, c0 ); + c01 = _mm512_unpackhi_epi16( a0, c0 ); + + a0 = _mm512_permutex2var_epi64( b0, selector1, a01 ); + c0 = _mm512_permutex2var_epi64( d0, selector1, c01 ); + b0 = _mm512_permutex2var_epi64( b0, selector1_1, a01 ); + d0 = _mm512_permutex2var_epi64( d0, selector1_1, c01 ); + + a01 = _mm512_permutex2var_epi64( a0, selector2, c0 ); // b[0] + c01 = _mm512_permutex2var_epi64( b0, selector2, d0 ); // b[2] + a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // b[1] + c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // b[3] + + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( kr + 0 ) * NR ) ), a01 ); + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( kr + 1 ) * NR ) ) , a0 ); + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( kr + 2 ) * NR ) ), c01 ); + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( kr + 3 ) * NR ) ), c0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + if ( k_partial_pieces == 3 ) + { + a0 = _mm512_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) + jc ); + b0 = _mm512_loadu_epi8( b + ( ldb * ( k_full_pieces + 1 ) ) + jc ); + c0 = _mm512_loadu_epi8( b + ( ldb * ( k_full_pieces + 2 ) ) + jc ); + d0 = _mm512_setzero_si512(); + + } + else if( k_partial_pieces == 2 ) + { + a0 = _mm512_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) + jc ); + b0 = _mm512_loadu_epi8( b + ( ldb * ( k_full_pieces + 1 ) ) + jc ); + c0 = _mm512_setzero_si512(); + d0 = _mm512_setzero_si512(); + } + else //k_partial_pieces == 1 + { + a0 = _mm512_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) + jc ); + b0 = _mm512_setzero_si512(); + c0 = _mm512_setzero_si512(); + d0 = _mm512_setzero_si512(); + } + + a01 = _mm512_unpacklo_epi8( a0, b0 ); + a0 = _mm512_unpackhi_epi8( a0, b0 ); + + c01 = _mm512_unpacklo_epi8( c0, d0 ); + c0 = _mm512_unpackhi_epi8( c0, d0 ); + + b0 = _mm512_unpacklo_epi16( a01, c01 ); + a01 = _mm512_unpackhi_epi16( a01, c01 ); + + d0 = _mm512_unpacklo_epi16( a0, c0 ); + c01 = _mm512_unpackhi_epi16( a0, c0 ); + + a0 = _mm512_permutex2var_epi64( b0, selector1, a01 ); + c0 = _mm512_permutex2var_epi64( d0, selector1, c01 ); + b0 = _mm512_permutex2var_epi64( b0, selector1_1, a01 ); + d0 = _mm512_permutex2var_epi64( d0, selector1_1, c01 ); + + a01 = _mm512_permutex2var_epi64( a0, selector2, c0 ); // b[0] + c01 = _mm512_permutex2var_epi64( b0, selector2, d0 ); // b[2] + a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // b[1] + c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // b[3] + + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( k_full_pieces + 0 ) * NR ) ), a01 ); + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( k_full_pieces + 1 ) * NR ) ) , a0 ); + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( k_full_pieces + 2 ) * NR ) ), c01 ); + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( k_full_pieces + 3 ) * NR ) ), c0 ); + } 
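The stores above establish the packed-B layout the compute kernels assume: for each k-group of 4 source rows, a 64-byte block holds 16 consecutive columns with their 4 k-values interleaved, and four such blocks cover NR = 64 columns, hence `rs_b = NR * 4` (advance one k-group) and `cs_b = NR` (advance one 16-column sub-block). When KC is not a multiple of 4, the missing rows are zero-filled and the buffer is sized with `KC_updated`, so the padded k positions add nothing in `vpdpbusd`. A sketch of the resulting byte offset of B(k, n) in a full-width panel, derived from the index arithmetic above (treat it as an illustration, not a contract):

// Byte offset of B(k, n) in the VNNI-packed buffer, for the full 64-wide
// panel that starts at column jc and is KC_updated k-values deep.
static inline dim_t packed_b_offset_ref
     (
       dim_t k,
       dim_t n,
       dim_t jc,
       dim_t KC_updated
     )
{
    dim_t nb = n - jc;               // column within this 64-wide panel
    return ( jc * KC_updated )       // base of the panel
         + ( ( k / 4 ) * 64 * 4 )    // k-group step: rs_b = NR * 4
         + ( ( nb / 16 ) * 64 )      // 16-column sub-block: cs_b = NR
         + ( ( nb % 16 ) * 4 )       // column within the sub-block
         + ( k % 4 );                // k within the group of 4
}

The n-fringe panels handled next reuse the same per-block format with fewer column sub-blocks.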
+ } + + // Contiguous packing of fringe panel (n` < NR). + if ( n_partial_pieces > 0 ) + { + dim_t n0_partial_rem = n_partial_pieces % 16; + dim_t n0_partial_pack = 0; + + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization after packing. Any n0 < NR(64) can be expressed + // as n0 = 48 + n` / n0 = 32 + n` / n0 = 16 + n`, where n` < 16. + dim_t n0_48 = n_partial_pieces / 48; + dim_t n0_32 = n_partial_pieces / 32; + dim_t n0_16 = n_partial_pieces / 16; + + if ( n0_48 == 1 ) + { + packb_nr48_u8s8s32o32 + ( + ( pack_b_buffer_u8s8s32o32 + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + n_full_pieces_loop_limit ), ldb, KC + ); + + n0_partial_pack = 48; + } + else if ( n0_32 == 1 ) + { + packb_nr32_u8s8s32o32 + ( + ( pack_b_buffer_u8s8s32o32 + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + n_full_pieces_loop_limit ), ldb, KC + ); + + n0_partial_pack = 32; + } + else if ( n0_16 == 1 ) + { + packb_nr16_u8s8s32o32 + ( + ( pack_b_buffer_u8s8s32o32 + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + n_full_pieces_loop_limit ), ldb, KC + ); + + n0_partial_pack = 16; + } + + if ( n0_partial_rem > 0 ) + { + packb_nrlt16_u8s8s32o32 + ( + ( pack_b_buffer_u8s8s32o32 + ( n_full_pieces_loop_limit * KC_updated ) + + ( n0_partial_pack * KC_updated ) ), + ( b + n_full_pieces_loop_limit + n0_partial_pack ), ldb, KC, + n0_partial_rem + ); + } + } + *rs_b = NR * 4; + *cs_b = NR; +} + +void packb_nr48_u8s8s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t ldb, + const dim_t KC + ) +{ + dim_t kr_new = 0; + + dim_t k_full_pieces_blks = KC / 4; + dim_t k_full_pieces = k_full_pieces_blks * 4; + dim_t k_partial_pieces = KC % 4; + + __m256i a0_32; + __m256i b0_32; + __m256i c0_32; + __m256i d0_32; + __m256i a01_32; + __m256i c01_32; + __m512i a0_zmm; + __m512i b0_zmm; + __m128i a0_16; + __m128i b0_16; + __m128i c0_16; + __m128i d0_16; + __m128i a01_16; + __m128i c01_16; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) + { + // Rearrange for vpdpbusd, read 4 rows from B with 32 elements in each row. + a0_32 = _mm256_loadu_epi8( b + ( ldb * ( kr + 0 ) ) ); + b0_32 = _mm256_loadu_epi8( b + ( ldb * ( kr + 1 ) ) ); + c0_32 = _mm256_loadu_epi8( b + ( ldb * ( kr + 2 ) ) ); + d0_32 = _mm256_loadu_epi8( b + ( ldb * ( kr + 3 ) ) ); + + a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); + c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); + + b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); + a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); + + d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); + c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); + + a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem + c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem + b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem + d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem + + a0_zmm = _mm512_castsi256_si512( a0_32 ); + a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); + b0_zmm = _mm512_castsi256_si512( c0_32 ); + b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); + + // First 4x32 elements. + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 1 ) * NR ), b0_zmm ); + + // Rearrange for vpdpbusd, read 4 rows from B with next 16 elements in each row. 
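A trailing panel of n0 < 64 columns is split into at most one 48-, 32- or 16-wide piece plus a tail of fewer than 16 columns, matching the widths of the n-fringe compute kernels; each piece is packed contiguously after the previous one, KC_updated k-values deep. The if/else chain above is equivalent to the following (illustrative restatement for n0 < 64, not part of the patch):

// Width of the single wide fringe piece chosen for a trailing panel of
// n0 < 64 columns; the remaining n0 % 16 columns go through the
// nrlt16 path.
static inline dim_t packb_fringe_main_width_ref( dim_t n0 )
{
    if      ( n0 >= 48 ) return 48;
    else if ( n0 >= 32 ) return 32;
    else if ( n0 >= 16 ) return 16;
    return 0;
}

Inside `packb_nr48_u8s8s32o32`, the first 32 columns of each row group were interleaved above; the loads below pick up the remaining 16 columns of the 48-wide piece.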
+ a0_16 = _mm_loadu_epi8( b + ( ldb * ( kr + 0 ) ) + ( 32 ) ); + b0_16 = _mm_loadu_epi8( b + ( ldb * ( kr + 1 ) ) + ( 32 ) ); + c0_16 = _mm_loadu_epi8( b + ( ldb * ( kr + 2 ) ) + ( 32 ) ); + d0_16 = _mm_loadu_epi8( b + ( ldb * ( kr + 3 ) ) + ( 32 ) ); + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 2 ) * NR ), a0_zmm ); + + // The 4th 16byte chunk will be ignored, since its not part of the original data, + // but is here due to the packing in 4 16byte chunks format. + kr_new += 3; + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + if ( k_partial_pieces == 3 ) + { + a0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 1 ) ) ); + c0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 2 ) ) ); + d0_32 = _mm256_setzero_si256(); + + a0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) + ( 32 ) ); + b0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 1 ) ) + ( 32 ) ); + c0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 2 ) ) + ( 32 ) ); + d0_16 = _mm_setzero_si128(); + + } + else if( k_partial_pieces == 2 ) + { + a0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 1 ) ) ); + c0_32 = _mm256_setzero_si256(); + d0_32 = _mm256_setzero_si256(); + + a0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) + ( 32 ) ); + b0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 1 ) ) + ( 32 ) ); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + else //k_partial_pieces == 1 + { + a0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_32 = _mm256_setzero_si256(); + c0_32 = _mm256_setzero_si256(); + d0_32 = _mm256_setzero_si256(); + + a0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) + ( 32 ) ); + b0_16 = _mm_setzero_si128(); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + + a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); + c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); + + b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); + a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); + + d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); + c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); + + a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem + c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem + b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem + d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem + + a0_zmm = _mm512_castsi256_si512( a0_32 ); + a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); + b0_zmm = _mm512_castsi256_si512( c0_32 ); + b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); + + // First 4x32 elements. 
+ _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 1 ) * NR ), b0_zmm ); + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 2 ) * NR ), a0_zmm ); + } +} + +void packb_nr32_u8s8s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t ldb, + const dim_t KC + ) +{ + dim_t kr_new = 0; + + dim_t k_full_pieces_blks = KC / 4; + dim_t k_full_pieces = k_full_pieces_blks * 4; + dim_t k_partial_pieces = KC % 4; + + __m256i a0_32; + __m256i b0_32; + __m256i c0_32; + __m256i d0_32; + __m256i a01_32; + __m256i c01_32; + __m512i a0_zmm; + __m512i b0_zmm; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) + { + // Rearrange for vpdpbusd, read 4 rows from B with 32 elements in each row. + a0_32 = _mm256_loadu_epi8( b + ( ldb * ( kr + 0 ) ) ); + b0_32 = _mm256_loadu_epi8( b + ( ldb * ( kr + 1 ) ) ); + c0_32 = _mm256_loadu_epi8( b + ( ldb * ( kr + 2 ) ) ); + d0_32 = _mm256_loadu_epi8( b + ( ldb * ( kr + 3 ) ) ); + + a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); + c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); + + b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); + a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); + + d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); + c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); + + a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem + c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem + b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem + d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem + + a0_zmm = _mm512_castsi256_si512( a0_32 ); + a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); + b0_zmm = _mm512_castsi256_si512( c0_32 ); + b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); + + // First 4x32 elements. + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 1 ) * NR ), b0_zmm ); + + // The 3rd and 4th 16byte chunk will be ignored, since its not part of the original data, + // but is here due to the packing in 4 16byte chunks format. + kr_new += 2; + } + // Handle k remainder. 
+ if ( k_partial_pieces > 0 ) + { + if ( k_partial_pieces == 3 ) + { + a0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 1 ) ) ); + c0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 2 ) ) ); + d0_32 = _mm256_setzero_si256(); + + } + else if( k_partial_pieces == 2 ) + { + a0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 1 ) ) ); + c0_32 = _mm256_setzero_si256(); + d0_32 = _mm256_setzero_si256(); + } + else //k_partial_pieces == 1 + { + a0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_32 = _mm256_setzero_si256(); + c0_32 = _mm256_setzero_si256(); + d0_32 = _mm256_setzero_si256(); + } + + a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); + c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); + + b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); + a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); + + d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); + c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); + + a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem + c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem + b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem + d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem + + a0_zmm = _mm512_castsi256_si512( a0_32 ); + a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); + b0_zmm = _mm512_castsi256_si512( c0_32 ); + b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); + + // First 4x32 elements. + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 1 ) * NR ), b0_zmm ); + } +} + +void packb_nr16_u8s8s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t ldb, + const dim_t KC + ) +{ + dim_t kr_new = 0; + + dim_t k_full_pieces_blks = KC / 4; + dim_t k_full_pieces = k_full_pieces_blks * 4; + dim_t k_partial_pieces = KC % 4; + + __m128i a0_16; + __m128i b0_16; + __m128i c0_16; + __m128i d0_16; + __m128i a01_16; + __m128i c01_16; + __m512i a0_zmm; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) + { + // Rearrange for vpdpbusd, read 4 rows from B with next 16 elements in each row. + a0_16 = _mm_loadu_epi8( b + ( ldb * ( kr + 0 ) ) ); + b0_16 = _mm_loadu_epi8( b + ( ldb * ( kr + 1 ) ) ); + c0_16 = _mm_loadu_epi8( b + ( ldb * ( kr + 2 ) ) ); + d0_16 = _mm_loadu_epi8( b + ( ldb * ( kr + 3 ) ) ); + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + + // The 2nd, 3rd, and 4th 16byte chunk will be ignored, since its not part of the original data, + // but is here due to the packing in 4 16byte chunks format. + kr_new += 1; + } + // Handle k remainder. 
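+	// (Illustrative sketch, not compiled: as in the wider variants, the
+	// remainder path below zero-pads the missing rows. For reference, a
+	// vpdpbusd-style consumer would read this packed layout as follows;
+	// a_u8 (a packed uint8 row of A) and jr (output column index) are
+	// hypothetical names used only for illustration.)
+#if 0
+	int32_t acc = 0;
+	for ( dim_t kr = 0; kr < KC; kr += 4 )
+	{
+		for ( dim_t p = 0; p < 4; p += 1 )
+		{
+			// Each 32-bit lane accumulates 4 u8*s8 products, which is exactly
+			// the group of 4 k-values packed per column above.
+			acc += ( int32_t )a_u8[ kr + p ] *
+			       ( int32_t )pack_b_buffer_u8s8s32o32[ ( ( kr / 4 ) * NR ) +
+			                                            ( jr * 4 ) + p ];
+		}
+	}
+#endif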
+ if ( k_partial_pieces > 0 ) + { + if ( k_partial_pieces == 3 ) + { + a0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 1 ) ) ); + c0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 2 ) ) ); + d0_16 = _mm_setzero_si128(); + + } + else if( k_partial_pieces == 2 ) + { + a0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 1 ) ) ); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + else //k_partial_pieces == 1 + { + a0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_16 = _mm_setzero_si128(); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + __m512i a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + } +} + +void packb_nrlt16_u8s8s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t ldb, + const dim_t KC, + const dim_t n0_partial_rem + ) +{ + int8_t buf0[16]; + int8_t buf1[16]; + int8_t buf2[16]; + int8_t buf3[16]; + + dim_t kr_new = 0; + + dim_t k_full_pieces_blks = KC / 4; + dim_t k_full_pieces = k_full_pieces_blks * 4; + dim_t k_partial_pieces = KC % 4; + + __m128i a0_16; + __m128i b0_16; + __m128i c0_16; + __m128i d0_16; + __m128i a01_16; + __m128i c01_16; + __m512i a0_zmm; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) + { + memcpy( buf0, ( b + ( ldb * ( kr + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + memcpy( buf1, ( b + ( ldb * ( kr + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + memcpy( buf2, ( b + ( ldb * ( kr + 2 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + memcpy( buf3, ( b + ( ldb * ( kr + 3 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + + // Rearrange for vpdpbusd, read 4 rows from B with next 16 elements in each row. + a0_16 = _mm_loadu_epi8( buf0 ); + b0_16 = _mm_loadu_epi8( buf1 ); + c0_16 = _mm_loadu_epi8( buf2 ); + d0_16 = _mm_loadu_epi8( buf3 ); + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + + // The 2nd, 3rd, and 4th 16byte chunk will be ignored, since its not part of the original data, + // but is here due to the packing in 4 16byte chunks format. 
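+		// (Alternative sketch, not compiled: the memcpy staging above keeps
+		// the loads from reading past the end of B when fewer than 16 columns
+		// remain. Assuming AVX512VL/BW masked loads are acceptable here, the
+		// same effect could be obtained directly, e.g. for the first row:)
+#if 0
+		__mmask16 lt16_mask = ( __mmask16 )( ( 1U << n0_partial_rem ) - 1 );
+		a0_16 = _mm_maskz_loadu_epi8( lt16_mask, b + ( ldb * ( kr + 0 ) ) );
+#endif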
+ kr_new += 1; + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + if ( k_partial_pieces == 3 ) + { + memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + memcpy( buf1, ( b + ( ldb * ( k_full_pieces + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + memcpy( buf2, ( b + ( ldb * ( k_full_pieces + 2 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + + a0_16 = _mm_loadu_epi8( buf0 ); + b0_16 = _mm_loadu_epi8( buf1 ); + c0_16 = _mm_loadu_epi8( buf2 ); + d0_16 = _mm_setzero_si128(); + + } + else if( k_partial_pieces == 2 ) + { + memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + memcpy( buf1, ( b + ( ldb * ( k_full_pieces + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + + a0_16 = _mm_loadu_epi8( buf0 ); + b0_16 = _mm_loadu_epi8( buf1 ); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + else //k_partial_pieces == 1 + { + memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + + a0_16 = _mm_loadu_epi8( buf0 ); + b0_16 = _mm_setzero_si128(); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + __m512i a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + } +} +#endif diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_s32_kern_macros.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_s32_kern_macros.h new file mode 100644 index 0000000000..bc3546736c --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_s32_kern_macros.h @@ -0,0 +1,103 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_S32_KERN_MACROS_H +#define LPGEMM_S32_KERN_MACROS_H +#define S8_MIN (-128) +#define S8_MAX (+127) + +#define RELU_SCALE_OP_S32_AVX512(reg) \ + /* Generate indenx of elements <= 0.*/ \ + relu_cmp_mask = _mm512_cmple_epi32_mask( reg, selector1 ); \ + \ + /* Apply scaling on for <= 0 elements.*/ \ + reg = _mm512_mask_mullo_epi32( reg, relu_cmp_mask, reg, selector2 ); \ + +#define CVT_MULRND_CVT32_CVT8(reg,selector,m_ind,n_ind) \ + _mm_storeu_epi8 \ + ( \ + ( int8_t* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + m_ind ) ) + post_op_c_j + ( n_ind * 16 ), \ + _mm512_cvtepi32_epi8 \ + ( \ + _mm512_cvtps_epi32 \ + ( \ + _mm512_min_ps \ + ( \ + _mm512_max_ps \ + ( \ + _mm512_mul_round_ps \ + ( \ + _mm512_cvtepi32_ps( reg ), \ + ( __m512 )selector, \ + ( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) \ + ) \ + , _mm512_set1_ps (( float )S8_MIN) \ + ) \ + , _mm512_set1_ps (( float )S8_MAX) \ + ) \ + ) \ + ) \ + ) \ + +#define CVT_MULRND_CVT32_CVT8_LT16(reg,selector,m_ind,n_ind) \ + _mm_storeu_epi8 \ + ( \ + buf0, \ + _mm512_cvtepi32_epi8 \ + ( \ + _mm512_cvtps_epi32 \ + ( \ + _mm512_min_ps \ + ( \ + _mm512_max_ps \ + ( \ + _mm512_mul_round_ps \ + ( \ + _mm512_cvtepi32_ps( reg ), \ + ( __m512 )selector, \ + ( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) \ + ) \ + , _mm512_set1_ps (( float )S8_MIN) \ + ) \ + , _mm512_set1_ps (( float )S8_MAX) \ + ) \ + ) \ + ) \ + ); \ + memcpy( ( int8_t* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + m_ind ) ) + post_op_c_j + \ + ( n_ind * 16 ) , buf0, ( n0_rem * sizeof( int8_t ) ) ); \ + +#endif // LPGEMM_S32_KERN_MACROS_H diff --git a/addon/gemmd/attic/bli_gemm_ex.c b/addon/gemmd/attic/bli_gemm_ex.c new file mode 100644 index 0000000000..0f40d1cb39 --- /dev/null +++ b/addon/gemmd/attic/bli_gemm_ex.c @@ -0,0 +1,88 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_gemm_ex + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + // A switch to easily toggle whether we use the addon implementation + // of bao_gemmd() as the implementation for bli_gemm(). (This allows for + // easy testing of bao_gemmd() via the testsuite.) + if ( 1 ) + { + const dim_t k = bli_obj_width_after_trans( a ); + const num_t dt = bli_obj_dt( c ); + obj_t d; + + bli_obj_create( dt, k, 1, 1, k, &d ); + bli_setv( &BLIS_ONE, &d ); + //bli_randv( &d ); + + bao_gemmd_ex( alpha, a, &d, b, beta, c, cntx, rntm ); + + bli_obj_free( &d ); + return; + } + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // Obtain a valid (native) context from the gks if necessary. + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + + // Check the operands. + if ( bli_error_checking_is_enabled() ) + bli_gemm_check( alpha, a, b, beta, c, cntx ); + + // Invoke the operation's front end. + bli_gemm_front + ( + alpha, a, b, beta, c, cntx, rntm, NULL + ); +} + diff --git a/addon/gemmd/bao_gemmd.c b/addon/gemmd/bao_gemmd.c new file mode 100644 index 0000000000..71d49806ba --- /dev/null +++ b/addon/gemmd/bao_gemmd.c @@ -0,0 +1,305 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// -- Define the gemmd operation's object API ---------------------------------- +// + +void bao_gemmd + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c + ) +{ + bao_gemmd_ex + ( + alpha, + a, + d, + b, + beta, + c, + NULL, + NULL + ); +} + +void bao_gemmd_ex + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // Obtain a valid (native) context from the gks if necessary. + // NOTE: This must be done before calling the _check() function, since + // that function assumes the context pointer is valid. + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + + // Check parameters. + if ( bli_error_checking_is_enabled() ) + bao_gemmd_check( alpha, a, d, b, beta, c, cntx ); + + // -- bli_gemmd_front() ---------------------------------------------------- + + obj_t a_local; + obj_t b_local; + obj_t c_local; + + // If C has a zero dimension, return early. + if ( bli_obj_has_zero_dim( c ) ) + { + return; + } + + // If alpha is zero, or if A or B has a zero dimension, scale C by beta + // and return early. + if ( bli_obj_equals( alpha, &BLIS_ZERO ) || + bli_obj_has_zero_dim( a ) || + bli_obj_has_zero_dim( b ) ) + { + bli_scalm( beta, c ); + return; + } + + // Alias A, B, and C in case we need to apply transformations. + bli_obj_alias_to( a, &a_local ); + bli_obj_alias_to( b, &b_local ); + bli_obj_alias_to( c, &c_local ); + + // Induce a transposition of A if it has its transposition property set. + // Then clear the transposition bit in the object. + if ( bli_obj_has_trans( &a_local ) ) + { + bli_obj_induce_trans( &a_local ); + bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &a_local ); + } + + // Induce a transposition of B if it has its transposition property set. + // Then clear the transposition bit in the object. + if ( bli_obj_has_trans( &b_local ) ) + { + bli_obj_induce_trans( &b_local ); + bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &b_local ); + } + + // An optimization: If C is stored by rows and the micro-kernel prefers + // contiguous columns, or if C is stored by columns and the micro-kernel + // prefers contiguous rows, transpose the entire operation to allow the + // micro-kernel to access elements of C in its preferred manner. 
+ if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) + { + bli_obj_swap( &a_local, &b_local ); + + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + } + + // Parse and interpret the contents of the rntm_t object to properly + // set the ways of parallelism for each loop, and then make any + // additional modifications necessary for the current operation. + bli_rntm_set_ways_for_op + ( + BLIS_GEMM, + BLIS_LEFT, // ignored for gemm/hemm/symm + bli_obj_length( &c_local ), + bli_obj_width( &c_local ), + bli_obj_width( &a_local ), + rntm + ); + + // Spawn threads (if applicable), where bao_gemmd_int() is the thread entry + // point function for each thread. This also begins the process of creating + // the thrinfo_t tree, which contains thread communicators. + bao_l3_thread_decorator + ( + bao_gemmd_int, + BLIS_GEMM, // operation family id + alpha, + &a_local, + d, + &b_local, + beta, + &c_local, + cntx, + rntm + ); +} + +// +// -- Define the gemmd operation's thread entry point -------------------------- +// + +void bao_gemmd_int + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + // In this function, we choose the gemmd implementation that is executed + // on each thread. + +#if 1 + // Call the block-panel algorithm that calls the kernel directly, which + // exposes edge-case handling. + bao_gemmd_bp_var1 + ( + alpha, + a, + d, + b, + beta, + c, + cntx, + rntm, + thread + ); +#else + // Call the block-panel algorithm that calls the kernel indirectly via a + // wrapper function, which hides edge-case handling. + bao_gemmd_bp_var2 + ( + alpha, + a, + d, + b, + beta, + c, + cntx, + rntm, + thread + ); +#endif +} + +// +// -- Define the gemmd operation's typed API ----------------------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* d, inc_t incd, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c \ + ) \ +{ \ + bli_init_once(); \ +\ + /* Determine the datatype (e.g. BLIS_FLOAT, BLIS_DOUBLE, etc.) based on + the macro parameter 'ch' (e.g. s, d, etc). */ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao, ao, dd, bo, betao, co; \ +\ + dim_t m_a, n_a; \ + dim_t m_b, n_b; \ +\ + /* Adjust the dimensions of matrices A and B according to the transa and + transb parameters. */ \ + bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ + bli_set_dims_with_trans( transb, k, n, &m_b, &n_b ); \ +\ + /* Create bufferless scalar objects and attach the provided scalar pointers + to those scalar objects. */ \ + bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ + bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ +\ + /* Create bufferless matrix objects and attach the provided matrix pointers + to those matrix objects. */ \ + bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_create_with_attached_buffer( dt, k, 1, d, incd, k, &dd ); \ + bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_create_with_attached_buffer( dt, m, n, c, rs_c, cs_c, &co ); \ +\ + /* Set the transposition/conjugation properties of the objects for matrices + A and B. 
*/ \ + bli_obj_set_conjtrans( transa, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + /* Call the object interface. */ \ + PASTECH(bao_,opname) \ + ( \ + &alphao, \ + &ao, \ + &dd, \ + &bo, \ + &betao, \ + &co \ + ); \ +} + +//INSERT_GENTFUNC_BASIC0( gemmd ) +GENTFUNC( float, s, gemmd ) +GENTFUNC( double, d, gemmd ) +GENTFUNC( scomplex, c, gemmd ) +GENTFUNC( dcomplex, z, gemmd ) + diff --git a/addon/gemmd/bao_gemmd.h b/addon/gemmd/bao_gemmd.h new file mode 100644 index 0000000000..7c7466494d --- /dev/null +++ b/addon/gemmd/bao_gemmd.h @@ -0,0 +1,105 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +// +// -- Prototype the gemmd operation's object API ------------------------------- +// + +BLIS_EXPORT_ADDON void bao_gemmd + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c + ); + +BLIS_EXPORT_ADDON void bao_gemmd_ex + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ); + +// +// -- Prototype the gemmd operation's thread entry point ----------------------- +// + +void bao_gemmd_int + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +// +// -- Prototype the gemmd operation's typed API -------------------------------- +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +BLIS_EXPORT_ADDON void PASTECH2(bao_,ch,opname) \ + ( \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* d, inc_t incd, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c \ + ); + +//INSERT_GENTPROT_BASIC0( gemmd ) +GENTPROT( float, s, gemmd ) +GENTPROT( double, d, gemmd ) +GENTPROT( scomplex, c, gemmd ) +GENTPROT( dcomplex, z, gemmd ) + diff --git a/addon/gemmd/bao_gemmd_bp_var1.c b/addon/gemmd/bao_gemmd_bp_var1.c new file mode 100644 index 0000000000..e042f1fd81 --- /dev/null +++ b/addon/gemmd/bao_gemmd_bp_var1.c @@ -0,0 +1,530 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" + +#define FUNCPTR_T gemmd_fp + +typedef void (*FUNCPTR_T) + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + void* restrict alpha, + void* restrict a, inc_t rs_a, inc_t cs_a, + void* restrict d, inc_t incd, + void* restrict b, inc_t rs_b, inc_t cs_b, + void* restrict beta, + void* restrict c, inc_t rs_c, inc_t cs_c, + cntx_t* restrict cntx, + rntm_t* restrict rntm, + thrinfo_t* restrict thread + ); + +// +// -- gemmd-like block-panel algorithm (object interface) ---------------------- +// + +// Define a function pointer array named ftypes and initialize its contents with +// the addresses of the typed functions defined below, bao_?gemmd_bp_var1(). +static FUNCPTR_T GENARRAY_PREF(ftypes,bao_,gemmd_bp_var1); + +void bao_gemmd_bp_var1 + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + const num_t dt = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width( a ); + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + const inc_t rs_a = bli_obj_row_stride( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + + void* restrict buf_d = bli_obj_buffer_at_off( d ); + const inc_t incd = bli_obj_vector_inc( d ); + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const inc_t cs_b = bli_obj_col_stride( b ); + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt, beta ); + + // Index into the function pointer array to extract the correct + // typed function pointer based on the chosen datatype. + FUNCPTR_T f = ftypes[dt]; + + // Invoke the function. + f + ( + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_d, incd, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + cntx, + rntm, + thread + ); +} + +// +// -- gemmd-like block-panel algorithm (typed interface) ----------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict d, inc_t incd, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t NC = bli_cntx_get_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t MC = bli_cntx_get_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t KC = bli_cntx_get_blksz_def_dt( dt, BLIS_KC, cntx ); \ +\ + /* Query the context for the microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Temporary C buffer for edge cases. 
Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = cs_c; \ + const inc_t jcstep_b = cs_b; \ +\ + const inc_t pcstep_a = cs_a; \ + const inc_t pcstep_d = incd; \ + const inc_t pcstep_b = rs_b; \ +\ + const inc_t icstep_c = rs_c; \ + const inc_t icstep_a = rs_a; \ +\ + const inc_t jrstep_c = cs_c * NR; \ +\ + const inc_t irstep_c = rs_c * MR; \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict d_00 = d; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + /* Make local copies of the scalars to prevent any unnecessary sharing of + cache lines between the cores' caches. */ \ + ctype alpha_local = *alpha_cast; \ + ctype beta_local = *beta_cast; \ + ctype one_local = *PASTEMAC(ch,1); \ + ctype zero_local = *PASTEMAC(ch,0); \ +\ + auxinfo_t aux; \ +\ + /* Initialize a mem_t entry for A and B. Strictly speaking, this is only + needed for the matrix we will be packing (if any), but we do it + unconditionally to be safe. */ \ + mem_t mem_a = BLIS_MEM_INITIALIZER; \ + mem_t mem_b = BLIS_MEM_INITIALIZER; \ +\ + /* Define an array of bszid_t ids, which will act as our substitute for + the cntl_t tree. */ \ + bszid_t bszids[8] = { BLIS_NC, /* 5th loop */ \ + BLIS_KC, /* 4th loop */ \ + BLIS_NO_PART, /* pack B */ \ + BLIS_MC, /* 3rd loop */ \ + BLIS_NO_PART, /* pack A */ \ + BLIS_NR, /* 2nd loop */ \ + BLIS_MR, /* 1st loop */ \ + BLIS_KR }; /* microkernel loop */ \ +\ + bszid_t* restrict bszids_jc = &bszids[0]; \ + bszid_t* restrict bszids_pc = &bszids[1]; \ + /*bszid_t* restrict bszids_pb = &bszids[2];*/ \ + bszid_t* restrict bszids_ic = &bszids[3]; \ + /*bszid_t* restrict bszids_pa = &bszids[4];*/ \ + bszid_t* restrict bszids_jr = &bszids[5]; \ + /*bszid_t* restrict bszids_ir = &bszids[6];*/ \ +\ + thrinfo_t* restrict thread_jc = NULL; \ + thrinfo_t* restrict thread_pc = NULL; \ + thrinfo_t* restrict thread_pb = NULL; \ + thrinfo_t* restrict thread_ic = NULL; \ + thrinfo_t* restrict thread_pa = NULL; \ + thrinfo_t* restrict thread_jr = NULL; \ + thrinfo_t* restrict thread_ir = NULL; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jc = thread; \ + bli_thrinfo_sup_grow( rntm, bszids_jc, thread_jc ); \ +\ + /* Compute the JC loop thread range for the current thread. */ \ + dim_t jc_start, jc_end; \ + bli_thread_range_sub( thread_jc, n, NR, FALSE, &jc_start, &jc_end ); \ + const dim_t n_local = jc_end - jc_start; \ +\ + /* Compute number of primary and leftover components of the JC loop. */ \ + /*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ \ + const dim_t jc_left = n_local % NC; \ +\ + /* Loop over the n dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = jc_start; jj < jc_end; jj += NC ) \ + { \ + /* Calculate the thread's current JC block dimension. */ \ + const dim_t nc_cur = ( NC <= jc_end - jj ? 
NC : jc_left ); \ +\ + ctype* restrict b_jc = b_00 + jj * jcstep_b; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_pc = bli_thrinfo_sub_node( thread_jc ); \ + bli_thrinfo_sup_grow( rntm, bszids_pc, thread_pc ); \ +\ + /* Compute the PC loop thread range for the current thread. */ \ + const dim_t pc_start = 0, pc_end = k; \ + const dim_t k_local = k; \ +\ + /* Compute number of primary and leftover components of the PC loop. */ \ + /*const dim_t pc_iter = ( k_local + KC - 1 ) / KC;*/ \ + const dim_t pc_left = k_local % KC; \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = pc_start; pp < pc_end; pp += KC ) \ + { \ + /* Calculate the thread's current PC block dimension. */ \ + const dim_t kc_cur = ( KC <= pc_end - pp ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_00 + pp * pcstep_a; \ + ctype* restrict d_pc = d_00 + pp * pcstep_d; \ + ctype* restrict b_pc = b_jc + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? &beta_local : &one_local ); \ +\ + ctype* b_use; \ + inc_t rs_b_use, cs_b_use, ps_b_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pb = bli_thrinfo_sub_node( thread_pc ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pb, thread_pb );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + B. Then call the packm implementation. */ \ + PASTECH2(bao_,ch,packm_b) \ + ( \ + conjb, \ + KC, NC, \ + kc_cur, nc_cur, NR, \ + &one_local, \ + d_pc, incd, \ + b_pc, rs_b, cs_b, \ + &b_use, &rs_b_use, &cs_b_use, \ + &ps_b_use, \ + cntx, \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ + /* Alias b_use so that it's clear this is our current block of + matrix B. */ \ + ctype* restrict b_pc_use = b_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_ic = bli_thrinfo_sub_node( thread_pb ); \ + bli_thrinfo_sup_grow( rntm, bszids_ic, thread_ic ); \ +\ + /* Compute the IC loop thread range for the current thread. */ \ + dim_t ic_start, ic_end; \ + bli_thread_range_sub( thread_ic, m, MR, FALSE, &ic_start, &ic_end ); \ + const dim_t m_local = ic_end - ic_start; \ +\ + /* Compute number of primary and leftover components of the IC loop. */ \ + /*const dim_t ic_iter = ( m_local + MC - 1 ) / MC;*/ \ + const dim_t ic_left = m_local % MC; \ +\ + /* Loop over the m dimension (MC rows at a time). */ \ + for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \ + { \ + /* Calculate the thread's current IC block dimension. */ \ + const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \ +\ + ctype* restrict a_ic = a_pc + ii * icstep_a; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + ctype* a_use; \ + inc_t rs_a_use, cs_a_use, ps_a_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pa = bli_thrinfo_sub_node( thread_ic ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pa, thread_pa );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + A. 
Then call the packm implementation. */ \ + PASTECH2(bao_,ch,packm_a) \ + ( \ + conja, \ + MC, KC, \ + mc_cur, kc_cur, MR, \ + &one_local, \ + d_pc, incd, \ + a_ic, rs_a, cs_a, \ + &a_use, &rs_a_use, &cs_a_use, \ + &ps_a_use, \ + cntx, \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ +\ + /* Alias a_use so that it's clear this is our current block of + matrix A. */ \ + ctype* restrict a_ic_use = a_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jr = bli_thrinfo_sub_node( thread_pa ); \ + bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \ +\ + /* Query the number of threads and thread ids for the JR loop. + NOTE: These values are only needed when computing the next + micropanel of B. */ \ + const dim_t jr_nt = bli_thread_n_way( thread_jr ); \ + const dim_t jr_tid = bli_thread_work_id( thread_jr ); \ +\ + /* Compute number of primary and leftover components of the JR loop. */ \ + dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ + dim_t jr_left = nc_cur % NR; \ +\ + /* Compute the JR loop thread range for the current thread. */ \ + dim_t jr_start, jr_end; \ + bli_thread_range_sub( thread_jr, jr_iter, 1, FALSE, &jr_start, &jr_end ); \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( dim_t j = jr_start; j < jr_end; j += 1 ) \ + { \ + const dim_t nr_cur \ + = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); \ +\ + ctype* restrict b_jr = b_pc_use + j * ps_b_use; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* Assume for now that our next panel of B to be the current panel + of B. */ \ + ctype* restrict b2 = b_jr; \ +\ + /* Identify the current thrinfo_t node. */ \ + thread_ir = bli_thrinfo_sub_node( thread_jr ); \ +\ + /* Query the number of threads and thread ids for the IR loop. + NOTE: These values are only needed when computing the next + micropanel of A. */ \ + const dim_t ir_nt = bli_thread_n_way( thread_ir ); \ + const dim_t ir_tid = bli_thread_work_id( thread_ir ); \ +\ + /* Compute number of primary and leftover components of the IR loop. */ \ + dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ + dim_t ir_left = mc_cur % MR; \ +\ + /* Compute the IR loop thread range for the current thread. */ \ + dim_t ir_start, ir_end; \ + bli_thread_range_sub( thread_ir, ir_iter, 1, FALSE, &ir_start, &ir_end ); \ +\ + /* Loop over the m dimension (MR rows at a time). */ \ + for ( dim_t i = ir_start; i < ir_end; i += 1 ) \ + { \ + const dim_t mr_cur \ + = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \ +\ + ctype* restrict a_ir = a_ic_use + i * ps_a_use; \ + ctype* restrict c_ir = c_jr + i * irstep_c; \ +\ + ctype* restrict a2; \ +\ + /* Compute the addresses of the next micropanels of A and B. */ \ + a2 = bli_gemm_get_next_a_upanel( a_ir, ps_a_use, 1 ); \ + if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ + { \ + a2 = a_ic_use; \ + b2 = bli_gemm_get_next_b_upanel( b_jr, ps_b_use, 1 ); \ + if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ + b2 = b_pc_use; \ + } \ +\ + /* Save the addresses of next micropanels of A and B to the + auxinfo_t object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* Handle interior and edge cases separately. */ \ + if ( mr_cur == MR && nr_cur == NR ) \ + { \ + /* Invoke the gemm microkernel. */ \ + gemm_ukr \ + ( \ + kc_cur, \ + &alpha_local, \ + a_ir, \ + b_jr, \ + beta_use, \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + else \ + { \ + /* Invoke the gemm microkernel. 
*/ \ + gemm_ukr \ + ( \ + kc_cur, \ + &alpha_local, \ + a_ir, \ + b_jr, \ + &zero_local, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ +\ + /* Scale the bottom edge of C and add the result from above. */ \ + PASTEMAC(ch,xpbys_mxn) \ + ( \ + mr_cur, \ + nr_cur, \ + ct, rs_ct, cs_ct, \ + beta_use, \ + c_ir, rs_c, cs_c \ + ); \ + } \ + } \ + } \ + } \ +\ + /* This barrier is needed to prevent threads from starting to pack + the next row panel of B before the current row panel is fully + computed upon. */ \ + bli_thread_barrier( thread_pb ); \ + } \ + } \ +\ + /* Release any memory that was acquired for packing matrices A and B. */ \ + PASTECH2(bao_,ch,packm_finalize_mem_a) \ + ( \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ + PASTECH2(bao_,ch,packm_finalize_mem_b) \ + ( \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemmd_bp_var1: a1_packed", mr_cur, kc_cur, a_ir, rs_a_use, cs_a_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmd_bp_var1: b1_packed", kc_cur, nr_cur, b_jr, rs_b_use, cs_b_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmd_bp_var1: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%5.2f", "" ); \ +*/ \ +} + +//INSERT_GENTFUNC_BASIC0( gemmd_bp_var1 ) +GENTFUNC( float, s, gemmd_bp_var1 ) +GENTFUNC( double, d, gemmd_bp_var1 ) +GENTFUNC( scomplex, c, gemmd_bp_var1 ) +GENTFUNC( dcomplex, z, gemmd_bp_var1 ) + diff --git a/addon/gemmd/bao_gemmd_bp_var2.c b/addon/gemmd/bao_gemmd_bp_var2.c new file mode 100644 index 0000000000..a0040fec06 --- /dev/null +++ b/addon/gemmd/bao_gemmd_bp_var2.c @@ -0,0 +1,602 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" + +#define FUNCPTR_T gemmd_fp + +typedef void (*FUNCPTR_T) + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + void* restrict alpha, + void* restrict a, inc_t rs_a, inc_t cs_a, + void* restrict d, inc_t incd, + void* restrict b, inc_t rs_b, inc_t cs_b, + void* restrict beta, + void* restrict c, inc_t rs_c, inc_t cs_c, + cntx_t* restrict cntx, + rntm_t* restrict rntm, + thrinfo_t* restrict thread + ); + +// +// -- gemmd-like block-panel algorithm (object interface) ---------------------- +// + +// Define a function pointer array named ftypes and initialize its contents with +// the addresses of the typed functions defined below, bao_?gemmd_bp_var2(). +static FUNCPTR_T GENARRAY_PREF(ftypes,bao_,gemmd_bp_var2); + +void bao_gemmd_bp_var2 + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + const num_t dt = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width( a ); + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + const inc_t rs_a = bli_obj_row_stride( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + + void* restrict buf_d = bli_obj_buffer_at_off( d ); + const inc_t incd = bli_obj_vector_inc( d ); + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const inc_t cs_b = bli_obj_col_stride( b ); + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt, beta ); + + // Index into the function pointer array to extract the correct + // typed function pointer based on the chosen datatype. + FUNCPTR_T f = ftypes[dt]; + + // Invoke the function. + f + ( + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_d, incd, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + cntx, + rntm, + thread + ); +} + +// +// -- gemmd-like block-panel algorithm (typed interface) ----------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict d, inc_t incd, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t NC = bli_cntx_get_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t MC = bli_cntx_get_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t KC = bli_cntx_get_blksz_def_dt( dt, BLIS_KC, cntx ); \ +\ + /* Query the context for the microkernel address and cast it to its + function pointer type. */ \ + /* + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ + */ \ +\ + /* Temporary C buffer for edge cases. 
Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + /* + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ + */ \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = cs_c; \ + const inc_t jcstep_b = cs_b; \ +\ + const inc_t pcstep_a = cs_a; \ + const inc_t pcstep_d = incd; \ + const inc_t pcstep_b = rs_b; \ +\ + const inc_t icstep_c = rs_c; \ + const inc_t icstep_a = rs_a; \ +\ + const inc_t jrstep_c = cs_c * NR; \ +\ + const inc_t irstep_c = rs_c * MR; \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict d_00 = d; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + /* Make local copies of the scalars to prevent any unnecessary sharing of + cache lines between the cores' caches. */ \ + ctype alpha_local = *alpha_cast; \ + ctype beta_local = *beta_cast; \ + ctype one_local = *PASTEMAC(ch,1); \ + /*ctype zero_local = *PASTEMAC(ch,0);*/ \ +\ + auxinfo_t aux; \ +\ + /* Initialize a mem_t entry for A and B. Strictly speaking, this is only + needed for the matrix we will be packing (if any), but we do it + unconditionally to be safe. */ \ + mem_t mem_a = BLIS_MEM_INITIALIZER; \ + mem_t mem_b = BLIS_MEM_INITIALIZER; \ +\ + /* Define an array of bszid_t ids, which will act as our substitute for + the cntl_t tree. */ \ + bszid_t bszids[8] = { BLIS_NC, /* 5th loop */ \ + BLIS_KC, /* 4th loop */ \ + BLIS_NO_PART, /* pack B */ \ + BLIS_MC, /* 3rd loop */ \ + BLIS_NO_PART, /* pack A */ \ + BLIS_NR, /* 2nd loop */ \ + BLIS_MR, /* 1st loop */ \ + BLIS_KR }; /* microkernel loop */ \ +\ + bszid_t* restrict bszids_jc = &bszids[0]; \ + bszid_t* restrict bszids_pc = &bszids[1]; \ + /*bszid_t* restrict bszids_pb = &bszids[2];*/ \ + bszid_t* restrict bszids_ic = &bszids[3]; \ + /*bszid_t* restrict bszids_pa = &bszids[4];*/ \ + bszid_t* restrict bszids_jr = &bszids[5]; \ + /*bszid_t* restrict bszids_ir = &bszids[6];*/ \ +\ + thrinfo_t* restrict thread_jc = NULL; \ + thrinfo_t* restrict thread_pc = NULL; \ + thrinfo_t* restrict thread_pb = NULL; \ + thrinfo_t* restrict thread_ic = NULL; \ + thrinfo_t* restrict thread_pa = NULL; \ + thrinfo_t* restrict thread_jr = NULL; \ + thrinfo_t* restrict thread_ir = NULL; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jc = thread; \ + bli_thrinfo_sup_grow( rntm, bszids_jc, thread_jc ); \ +\ + /* Compute the JC loop thread range for the current thread. */ \ + dim_t jc_start, jc_end; \ + bli_thread_range_sub( thread_jc, n, NR, FALSE, &jc_start, &jc_end ); \ + const dim_t n_local = jc_end - jc_start; \ +\ + /* Compute number of primary and leftover components of the JC loop. */ \ + /*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ \ + const dim_t jc_left = n_local % NC; \ +\ + /* Loop over the n dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = jc_start; jj < jc_end; jj += NC ) \ + { \ + /* Calculate the thread's current JC block dimension. */ \ + const dim_t nc_cur = ( NC <= jc_end - jj ? 
NC : jc_left ); \ +\ + ctype* restrict b_jc = b_00 + jj * jcstep_b; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_pc = bli_thrinfo_sub_node( thread_jc ); \ + bli_thrinfo_sup_grow( rntm, bszids_pc, thread_pc ); \ +\ + /* Compute the PC loop thread range for the current thread. */ \ + const dim_t pc_start = 0, pc_end = k; \ + const dim_t k_local = k; \ +\ + /* Compute number of primary and leftover components of the PC loop. */ \ + /*const dim_t pc_iter = ( k_local + KC - 1 ) / KC;*/ \ + const dim_t pc_left = k_local % KC; \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = pc_start; pp < pc_end; pp += KC ) \ + { \ + /* Calculate the thread's current PC block dimension. */ \ + const dim_t kc_cur = ( KC <= pc_end - pp ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_00 + pp * pcstep_a; \ + ctype* restrict d_pc = d_00 + pp * pcstep_d; \ + ctype* restrict b_pc = b_jc + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? &beta_local : &one_local ); \ +\ + ctype* b_use; \ + inc_t rs_b_use, cs_b_use, ps_b_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pb = bli_thrinfo_sub_node( thread_pc ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pb, thread_pb );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + B. Then call the packm implementation. */ \ + PASTECH2(bao_,ch,packm_b) \ + ( \ + conjb, \ + KC, NC, \ + kc_cur, nc_cur, NR, \ + &one_local, \ + d_pc, incd, \ + b_pc, rs_b, cs_b, \ + &b_use, &rs_b_use, &cs_b_use, \ + &ps_b_use, \ + cntx, \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ + /* Alias b_use so that it's clear this is our current block of + matrix B. */ \ + ctype* restrict b_pc_use = b_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_ic = bli_thrinfo_sub_node( thread_pb ); \ + bli_thrinfo_sup_grow( rntm, bszids_ic, thread_ic ); \ +\ + /* Compute the IC loop thread range for the current thread. */ \ + dim_t ic_start, ic_end; \ + bli_thread_range_sub( thread_ic, m, MR, FALSE, &ic_start, &ic_end ); \ + const dim_t m_local = ic_end - ic_start; \ +\ + /* Compute number of primary and leftover components of the IC loop. */ \ + /*const dim_t ic_iter = ( m_local + MC - 1 ) / MC;*/ \ + const dim_t ic_left = m_local % MC; \ +\ + /* Loop over the m dimension (MC rows at a time). */ \ + for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \ + { \ + /* Calculate the thread's current IC block dimension. */ \ + const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \ +\ + ctype* restrict a_ic = a_pc + ii * icstep_a; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + ctype* a_use; \ + inc_t rs_a_use, cs_a_use, ps_a_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pa = bli_thrinfo_sub_node( thread_ic ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pa, thread_pa );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + A. 
Then call the packm implementation. */ \ + PASTECH2(bao_,ch,packm_a) \ + ( \ + conja, \ + MC, KC, \ + mc_cur, kc_cur, MR, \ + &one_local, \ + d_pc, incd, \ + a_ic, rs_a, cs_a, \ + &a_use, &rs_a_use, &cs_a_use, \ + &ps_a_use, \ + cntx, \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ +\ + /* Alias a_use so that it's clear this is our current block of + matrix A. */ \ + ctype* restrict a_ic_use = a_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jr = bli_thrinfo_sub_node( thread_pa ); \ + bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \ +\ + /* Query the number of threads and thread ids for the JR loop. + NOTE: These values are only needed when computing the next + micropanel of B. */ \ + const dim_t jr_nt = bli_thread_n_way( thread_jr ); \ + const dim_t jr_tid = bli_thread_work_id( thread_jr ); \ +\ + /* Compute number of primary and leftover components of the JR loop. */ \ + dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ + dim_t jr_left = nc_cur % NR; \ +\ + /* Compute the JR loop thread range for the current thread. */ \ + dim_t jr_start, jr_end; \ + bli_thread_range_sub( thread_jr, jr_iter, 1, FALSE, &jr_start, &jr_end ); \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( dim_t j = jr_start; j < jr_end; j += 1 ) \ + { \ + const dim_t nr_cur \ + = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); \ +\ + ctype* restrict b_jr = b_pc_use + j * ps_b_use; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* Assume for now that our next panel of B to be the current panel + of B. */ \ + ctype* restrict b2 = b_jr; \ +\ + /* Identify the current thrinfo_t node. */ \ + thread_ir = bli_thrinfo_sub_node( thread_jr ); \ +\ + /* Query the number of threads and thread ids for the IR loop. + NOTE: These values are only needed when computing the next + micropanel of A. */ \ + const dim_t ir_nt = bli_thread_n_way( thread_ir ); \ + const dim_t ir_tid = bli_thread_work_id( thread_ir ); \ +\ + /* Compute number of primary and leftover components of the IR loop. */ \ + dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ + dim_t ir_left = mc_cur % MR; \ +\ + /* Compute the IR loop thread range for the current thread. */ \ + dim_t ir_start, ir_end; \ + bli_thread_range_sub( thread_ir, ir_iter, 1, FALSE, &ir_start, &ir_end ); \ +\ + /* Loop over the m dimension (MR rows at a time). */ \ + for ( dim_t i = ir_start; i < ir_end; i += 1 ) \ + { \ + const dim_t mr_cur \ + = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \ +\ + ctype* restrict a_ir = a_ic_use + i * ps_a_use; \ + ctype* restrict c_ir = c_jr + i * irstep_c; \ +\ + ctype* restrict a2; \ +\ + /* Compute the addresses of the next micropanels of A and B. */ \ + a2 = bli_gemm_get_next_a_upanel( a_ir, ps_a_use, 1 ); \ + if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ + { \ + a2 = a_ic_use; \ + b2 = bli_gemm_get_next_b_upanel( b_jr, ps_b_use, 1 ); \ + if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ + b2 = b_pc_use; \ + } \ +\ + /* Save the addresses of next micropanels of A and B to the + auxinfo_t object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* Call a wrapper to the kernel (which handles edge cases). 
*/ \ + PASTECH2(bao_,ch,gemm_kernel) \ + ( \ + MR, \ + NR, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + &alpha_local, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + beta_use, \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ +\ + /* This barrier is needed to prevent threads from starting to pack + the next row panel of B before the current row panel is fully + computed upon. */ \ + bli_thread_barrier( thread_pb ); \ + } \ + } \ +\ + /* Release any memory that was acquired for packing matrices A and B. */ \ + PASTECH2(bao_,ch,packm_finalize_mem_a) \ + ( \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ + PASTECH2(bao_,ch,packm_finalize_mem_b) \ + ( \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemmd_bp_var2: a1_packed", mr_cur, kc_cur, a_ir, rs_a_use, cs_a_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmd_bp_var2: b1_packed", kc_cur, nr_cur, b_jr, rs_b_use, cs_b_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmd_bp_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%5.2f", "" ); \ +*/ \ +} + +//INSERT_GENTFUNC_BASIC0( gemmd_bp_var2 ) +GENTFUNC( float, s, gemmd_bp_var2 ) +GENTFUNC( double, d, gemmd_bp_var2 ) +GENTFUNC( scomplex, c, gemmd_bp_var2 ) +GENTFUNC( dcomplex, z, gemmd_bp_var2 ) + +// +// -- gemm-like microkernel wrapper -------------------------------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + const dim_t MR, \ + const dim_t NR, \ + dim_t mr_cur, \ + dim_t nr_cur, \ + dim_t kc_cur, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict aux, \ + cntx_t* restrict cntx \ + ) \ +{ \ + /* Infer the datatype from the ctype. */ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for the microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +\ + ctype zero = *PASTEMAC(ch,0); \ +\ + /* Handle interior and edge cases separately. */ \ + if ( mr_cur == MR && nr_cur == NR ) \ + { \ + /* Invoke the gemm microkernel. */ \ + gemm_ukr \ + ( \ + kc_cur, \ + alpha, \ + a, \ + b, \ + beta, \ + c, rs_c, cs_c, \ + aux, \ + cntx \ + ); \ + } \ + else \ + { \ + /* Invoke the gemm microkernel. */ \ + gemm_ukr \ + ( \ + kc_cur, \ + alpha, \ + a, \ + b, \ + &zero, \ + ct, rs_ct, cs_ct, \ + aux, \ + cntx \ + ); \ +\ + /* Scale the bottom edge of C and add the result from above. 
*/ \ + PASTEMAC(ch,xpbys_mxn) \ + ( \ + mr_cur, \ + nr_cur, \ + ct, rs_ct, cs_ct, \ + beta, \ + c, rs_c, cs_c \ + ); \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( gemm_kernel ) +GENTFUNC( float, s, gemm_kernel ) +GENTFUNC( double, d, gemm_kernel ) +GENTFUNC( scomplex, c, gemm_kernel ) +GENTFUNC( dcomplex, z, gemm_kernel ) + diff --git a/addon/gemmd/bao_gemmd_check.c b/addon/gemmd/bao_gemmd_check.c new file mode 100644 index 0000000000..864e9a1acb --- /dev/null +++ b/addon/gemmd/bao_gemmd_check.c @@ -0,0 +1,131 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bao_gemmd_check + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx + ) +{ + err_t e_val; + + // Check object datatypes. + + e_val = bli_check_noninteger_object( alpha ); + bli_check_error_code( e_val ); + + e_val = bli_check_noninteger_object( beta ); + bli_check_error_code( e_val ); + + e_val = bli_check_floating_object( a ); + bli_check_error_code( e_val ); + + e_val = bli_check_floating_object( d ); + bli_check_error_code( e_val ); + + e_val = bli_check_floating_object( b ); + bli_check_error_code( e_val ); + + e_val = bli_check_floating_object( c ); + bli_check_error_code( e_val ); + + // Check scalar/vector/matrix type. + + e_val = bli_check_scalar_object( alpha ); + bli_check_error_code( e_val ); + + e_val = bli_check_scalar_object( beta ); + bli_check_error_code( e_val ); + + e_val = bli_check_matrix_object( a ); + bli_check_error_code( e_val ); + + e_val = bli_check_vector_object( d ); + bli_check_error_code( e_val ); + + e_val = bli_check_matrix_object( b ); + bli_check_error_code( e_val ); + + e_val = bli_check_matrix_object( c ); + bli_check_error_code( e_val ); + + // Check object buffers (for non-NULLness). 
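For reference, the operation that the blocked variant above implements, and whose operands bao_gemmd_check() is validating here, is C := beta * C + alpha * A * diag(d) * B, with d a length-k vector. A naive scalar sketch in plain C (double precision, explicit element strides; illustrative only, not the addon's code path):

    #include <stddef.h>

    /* Reference gemmd: C := beta*C + alpha * A * diag(d) * B, where
       A is m x k, B is k x n, C is m x n, and d has k entries with
       stride incd. rs_x and cs_x are row/column strides in elements. */
    void gemmd_ref( size_t m, size_t n, size_t k,
                    double alpha,
                    const double* a, size_t rs_a, size_t cs_a,
                    const double* d, size_t incd,
                    const double* b, size_t rs_b, size_t cs_b,
                    double beta,
                    double* c, size_t rs_c, size_t cs_c )
    {
        for ( size_t i = 0; i < m; ++i )
            for ( size_t j = 0; j < n; ++j )
            {
                double acc = 0.0;

                for ( size_t l = 0; l < k; ++l )
                    acc += a[ i*rs_a + l*cs_a ]
                         * d[ l*incd ]
                         * b[ l*rs_b + j*cs_b ];

                c[ i*rs_c + j*cs_c ] = beta * c[ i*rs_c + j*cs_c ] + alpha * acc;
            }
    }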
+ + e_val = bli_check_object_buffer( alpha ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( a ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( d ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( b ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( beta ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( c ); + bli_check_error_code( e_val ); + + // Check object dimensions. + + e_val = bli_check_level3_dims( a, b, c ); + bli_check_error_code( e_val ); + + e_val = bli_check_vector_dim_equals( d, bli_obj_width_after_trans( a ) ); + bli_check_error_code( e_val ); + + // Check for consistent datatypes. + // NOTE: We only perform these tests when mixed datatype support is + // disabled. + + e_val = bli_check_consistent_object_datatypes( c, a ); + bli_check_error_code( e_val ); + + e_val = bli_check_consistent_object_datatypes( c, d ); + bli_check_error_code( e_val ); + + e_val = bli_check_consistent_object_datatypes( c, b ); + bli_check_error_code( e_val ); +} + diff --git a/addon/gemmd/bao_gemmd_check.h b/addon/gemmd/bao_gemmd_check.h new file mode 100644 index 0000000000..243ec70c8c --- /dev/null +++ b/addon/gemmd/bao_gemmd_check.h @@ -0,0 +1,50 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +// +// Prototype object-based check functions. +// + +void bao_gemmd_check + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx + ); + diff --git a/addon/gemmd/bao_gemmd_var.h b/addon/gemmd/bao_gemmd_var.h new file mode 100644 index 0000000000..5c66747275 --- /dev/null +++ b/addon/gemmd/bao_gemmd_var.h @@ -0,0 +1,126 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
+ + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +// +// Prototype the object-based variant interfaces. +// + +#undef GENPROT +#define GENPROT( opname ) \ +\ +void PASTECH(bao_,opname) \ + ( \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* d, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm, \ + thrinfo_t* thread \ + ); + +GENPROT( gemmd_bp_var1 ) +GENPROT( gemmd_bp_var2 ) + + +// +// Prototype the typed variant interfaces. +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict d, inc_t incd, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ); + +//INSERT_GENTPROT_BASIC0( gemmd_bp_var1 ) +GENTPROT( float, s, gemmd_bp_var1 ) +GENTPROT( double, d, gemmd_bp_var1 ) +GENTPROT( scomplex, c, gemmd_bp_var1 ) +GENTPROT( dcomplex, z, gemmd_bp_var1 ) + +//INSERT_GENTPROT_BASIC0( gemmd_bp_var2 ) +GENTPROT( float, s, gemmd_bp_var2 ) +GENTPROT( double, d, gemmd_bp_var2 ) +GENTPROT( scomplex, c, gemmd_bp_var2 ) +GENTPROT( dcomplex, z, gemmd_bp_var2 ) + + +// +// Prototype the typed kernel interfaces. 
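The GENTPROT/GENTFUNC blocks above rely on the PASTECH2 name-pasting macros to splice the bao_ prefix, the datatype character, and the operation name into a single symbol. As a rough illustration only (the exact expansion is an assumption here), the double-precision variant prototype generated above should reduce to something like:

    /* Approximate expansion of GENTPROT( double, d, gemmd_bp_var2 ). */
    void bao_dgemmd_bp_var2
         (
           conj_t  conja,
           conj_t  conjb,
           dim_t   m,
           dim_t   n,
           dim_t   k,
           void*   restrict alpha,
           void*   restrict a, inc_t rs_a, inc_t cs_a,
           void*   restrict d, inc_t incd,
           void*   restrict b, inc_t rs_b, inc_t cs_b,
           void*   restrict beta,
           void*   restrict c, inc_t rs_c, inc_t cs_c,
           cntx_t* restrict cntx,
           rntm_t* restrict rntm,
           thrinfo_t* restrict thread
         );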
+// + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + const dim_t MR, \ + const dim_t NR, \ + dim_t mr_cur, \ + dim_t nr_cur, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict aux, \ + cntx_t* restrict cntx \ + ); + +//INSERT_GENTPROT_BASIC0( gemm_kernel ) +GENTPROT( float, s, gemm_kernel ) +GENTPROT( double, d, gemm_kernel ) +GENTPROT( scomplex, c, gemm_kernel ) +GENTPROT( dcomplex, z, gemm_kernel ) + diff --git a/addon/gemmd/bao_l3_packm_a.c b/addon/gemmd/bao_l3_packm_a.c new file mode 100644 index 0000000000..49bb34664c --- /dev/null +++ b/addon/gemmd/bao_l3_packm_a.c @@ -0,0 +1,330 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + /* Set the pack buffer type so that we are obtaining memory blocks from + the pool dedicated to blocks of A. */ \ + const packbuf_t pack_buf_type = BLIS_BUFFER_FOR_A_BLOCK; \ +\ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + const dim_t m_pack = ( m / mr + ( m % mr ? 1 : 0 ) ) * mr; \ + const dim_t k_pack = k; \ +\ + /* Barrier to make sure all threads are caught up and ready to begin the + packm stage. */ \ + bli_thread_barrier( thread ); \ +\ + /* Compute the size of the memory block eneded. 
*/ \ + siz_t size_needed = sizeof( ctype ) * m_pack * k_pack; \ +\ + /* Check the mem_t entry provided by the caller. If it is unallocated, + then we need to acquire a block from the packed block allocator. */ \ + if ( bli_mem_is_unalloc( mem ) ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Acquire directly to the chief thread's mem_t that was passed in. + It needs to be that mem_t struct, and not a local (temporary) + mem_t, since there is no barrier until after packing is finished, + which could allow a race condition whereby the chief thread exits + the current function before the other threads have a chance to + copy from it. (A barrier would fix that race condition, but then + again, I prefer to keep barriers to a minimum.) */ \ + bli_pba_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t to all + threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else /* if ( bli_mem_is_alloc( mem ) ) */ \ + { \ + /* If the mem_t entry provided by the caller does NOT contain a NULL + buffer, then a block has already been acquired from the packed + block allocator and cached by the caller. */ \ +\ + /* As a sanity check, we should make sure that the mem_t object isn't + associated with a block that is too small compared to the size of + the packed matrix buffer that is needed, according to the value + computed above. */ \ + siz_t mem_size = bli_mem_size( mem ); \ +\ + if ( mem_size < size_needed ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* The chief thread releases the existing block associated + with the mem_t, and then re-acquires a new block, saving + the associated mem_t to its passed-in mem_t. (See coment + above for why the acquisition needs to be directly to + the chief thread's passed-in mem_t and not a local + (temporary) mem_t. */ \ + bli_pba_release \ + ( \ + rntm, \ + mem \ + ); \ + bli_pba_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t + to all threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else \ + { \ + /* If the mem_t entry is already allocated and sufficiently large, + then we use it as-is. No action is needed. */ \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_mem_a ) +GENTFUNC( float, s, packm_init_mem_a ) +GENTFUNC( double, d, packm_init_mem_a ) +GENTFUNC( scomplex, c, packm_init_mem_a ) +GENTFUNC( dcomplex, z, packm_init_mem_a ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + if ( thread != NULL ) \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Check the mem_t entry provided by the caller. Only proceed if it + is allocated, which it should be. 
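The sizing and reuse policy implemented by packm_init_mem_a() above comes down to a small amount of arithmetic plus three outcomes: acquire, re-acquire, or reuse. A condensed single-thread sketch in plain C (malloc/free stand in for the packed-block allocator, and the chief-thread broadcast is omitted):

    #include <stdlib.h>

    typedef struct { void* buf; size_t size; } mem_sketch_t;

    static void acquire( mem_sketch_t* mem, size_t size )
    { mem->buf = malloc( size ); mem->size = size; }

    static void release( mem_sketch_t* mem )
    { free( mem->buf ); mem->buf = NULL; mem->size = 0; }

    void init_mem_a_sketch( size_t m, size_t k, size_t mr, size_t elem_size,
                            mem_sketch_t* mem )
    {
        /* Round m up to a multiple of MR so that every micropanel,
           including the last one, has the same leading dimension.  */
        const size_t m_pack      = ( ( m + mr - 1 ) / mr ) * mr;
        const size_t size_needed = elem_size * m_pack * k;

        if ( mem->buf == NULL )
        {
            acquire( mem, size_needed );       /* first use                 */
        }
        else if ( mem->size < size_needed )
        {
            release( mem );                    /* cached block is too small */
            acquire( mem, size_needed );       /* so re-acquire             */
        }
        /* else: the cached block is large enough; reuse it as-is.          */
    }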
*/ \ + if ( bli_mem_is_alloc( mem ) ) \ + { \ + bli_pba_release \ + ( \ + rntm, \ + mem \ + ); \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_finalize_mem_a ) +GENTFUNC( float, s, packm_finalize_mem_a ) +GENTFUNC( double, d, packm_finalize_mem_a ) +GENTFUNC( scomplex, c, packm_finalize_mem_a ) +GENTFUNC( dcomplex, z, packm_finalize_mem_a ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + dim_t* restrict m_max, \ + dim_t* restrict k_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ) \ +{ \ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + *m_max = ( m / mr + ( m % mr ? 1 : 0 ) ) * mr; \ + *k_max = k; \ +\ + /* Determine the dimensions and strides for the packed matrix A. */ \ + { \ + /* Pack A to column-stored row-panels. */ \ + *rs_p = 1; \ + *cs_p = mr; \ +\ + *pd_p = mr; \ + *ps_p = mr * k; \ +\ + /* Set the schema to "packed row panels" to indicate packing to + conventional column-stored row panels. */ \ + *schema = BLIS_PACKED_ROW_PANELS; \ + } \ +\ + /* Set the buffer address provided by the caller to point to the memory + associated with the mem_t entry acquired from the memory pool. */ \ + *p = bli_mem_buffer( mem ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_a ) +GENTFUNC( float, s, packm_init_a ) +GENTFUNC( double, d, packm_init_a ) +GENTFUNC( scomplex, c, packm_init_a ) +GENTFUNC( dcomplex, z, packm_init_a ) + + +// +// Define BLAS-like interfaces to the variant chooser. +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t m_alloc, \ + dim_t k_alloc, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + pack_t schema; \ + dim_t m_max; \ + dim_t k_max; \ + dim_t pd_p; \ +\ + /* Prepare the packing destination buffer. */ \ + PASTECH2(bao_,ch,packm_init_mem_a) \ + ( \ + m_alloc, k_alloc, mr, \ + cntx, \ + rntm, \ + mem, \ + thread \ + ); \ +\ + /* Determine the packing buffer and related parameters for matrix A. */ \ + PASTECH2(bao_,ch,packm_init_a) \ + ( \ + &schema, \ + m, k, mr, \ + &m_max, &k_max, \ + p, rs_p, cs_p, \ + &pd_p, ps_p, \ + mem \ + ); \ +\ + /* Pack matrix A to the destination buffer chosen above. Here, the packed + matrix is stored to column-stored MR x k micropanels. */ \ + PASTECH2(bao_,ch,packm_var1) \ + ( \ + conj, \ + schema, \ + m, \ + k, \ + m_max, \ + k_max, \ + kappa, \ + d, incd, \ + a, rs_a, cs_a, \ + *p, *rs_p, *cs_p, \ + pd_p, *ps_p, \ + cntx, \ + thread \ + ); \ +\ + /* Barrier so that packing is done before computation. 
*/ \ + bli_thread_barrier( thread ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_a ) +GENTFUNC( float, s, packm_a ) +GENTFUNC( double, d, packm_a ) +GENTFUNC( scomplex, c, packm_a ) +GENTFUNC( dcomplex, z, packm_a ) + diff --git a/addon/gemmd/bao_l3_packm_a.h b/addon/gemmd/bao_l3_packm_a.h new file mode 100644 index 0000000000..b683b79d4a --- /dev/null +++ b/addon/gemmd/bao_l3_packm_a.h @@ -0,0 +1,123 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
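Taken together, bao_l3_packm_a.c stores A as column-stored MR x k micropanels placed back to back (rs_p = 1, cs_p = MR, ps_p = MR * k), with the row count rounded up to a multiple of MR. The index arithmetic, in an illustrative plain-C sketch that ignores kappa, the diagonal d, and zero-padding of a final partial panel:

    #include <stddef.h>

    /* Offset of logical element (i, l) of A inside the packed buffer. */
    static size_t packed_a_offset( size_t i, size_t l, size_t k, size_t MR )
    {
        const size_t panel  = i / MR;      /* which MR-row micropanel    */
        const size_t i_in_p = i % MR;      /* row within that micropanel */

        return panel * ( MR * k )          /* ps_p: skip earlier panels  */
             + l     *   MR                /* cs_p: column within panel  */
             + i_in_p;                     /* rs_p = 1                   */
    }

    /* Pack an m x k matrix A (element strides rs_a, cs_a) into a buffer
       p of at least ceil(m/MR)*MR*k elements using that layout.         */
    void pack_a_sketch( size_t m, size_t k, size_t MR,
                        const double* a, size_t rs_a, size_t cs_a,
                        double* p )
    {
        for ( size_t l = 0; l < k; ++l )
            for ( size_t i = 0; i < m; ++i )
                p[ packed_a_offset( i, l, k, MR ) ] = a[ i*rs_a + l*cs_a ];
    }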
+ +*/ + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_mem_a ) +GENTPROT( float, s, packm_init_mem_a ) +GENTPROT( double, d, packm_init_mem_a ) +GENTPROT( scomplex, c, packm_init_mem_a ) +GENTPROT( dcomplex, z, packm_init_mem_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_finalize_mem_a ) +GENTPROT( float, s, packm_finalize_mem_a ) +GENTPROT( double, d, packm_finalize_mem_a ) +GENTPROT( scomplex, c, packm_finalize_mem_a ) +GENTPROT( dcomplex, z, packm_finalize_mem_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + dim_t* restrict m_max, \ + dim_t* restrict k_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_a ) +GENTPROT( float, s, packm_init_a ) +GENTPROT( double, d, packm_init_a ) +GENTPROT( scomplex, c, packm_init_a ) +GENTPROT( dcomplex, z, packm_init_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t m_alloc, \ + dim_t k_alloc, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_a ) +GENTPROT( float, s, packm_a ) +GENTPROT( double, d, packm_a ) +GENTPROT( scomplex, c, packm_a ) +GENTPROT( dcomplex, z, packm_a ) + diff --git a/addon/gemmd/bao_l3_packm_b.c b/addon/gemmd/bao_l3_packm_b.c new file mode 100644 index 0000000000..c41b062b6e --- /dev/null +++ b/addon/gemmd/bao_l3_packm_b.c @@ -0,0 +1,330 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + /* Set the pack buffer type so that we are obtaining memory blocks from + the pool dedicated to panels of B. */ \ + const packbuf_t pack_buf_type = BLIS_BUFFER_FOR_B_PANEL; \ +\ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + const dim_t k_pack = k; \ + const dim_t n_pack = ( n / nr + ( n % nr ? 1 : 0 ) ) * nr; \ +\ + /* Barrier to make sure all threads are caught up and ready to begin the + packm stage. */ \ + bli_thread_barrier( thread ); \ +\ + /* Compute the size of the memory block eneded. */ \ + siz_t size_needed = sizeof( ctype ) * k_pack * n_pack; \ +\ + /* Check the mem_t entry provided by the caller. If it is unallocated, + then we need to acquire a block from the packed block allocator. */ \ + if ( bli_mem_is_unalloc( mem ) ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Acquire directly to the chief thread's mem_t that was passed in. + It needs to be that mem_t struct, and not a local (temporary) + mem_t, since there is no barrier until after packing is finished, + which could allow a race condition whereby the chief thread exits + the current function before the other threads have a chance to + copy from it. (A barrier would fix that race condition, but then + again, I prefer to keep barriers to a minimum.) */ \ + bli_pba_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t to all + threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else /* if ( bli_mem_is_alloc( mem ) ) */ \ + { \ + /* If the mem_t entry provided by the caller does NOT contain a NULL + buffer, then a block has already been acquired from the packed + block allocator and cached by the caller. */ \ +\ + /* As a sanity check, we should make sure that the mem_t object isn't + associated with a block that is too small compared to the size of + the packed matrix buffer that is needed, according to the value + computed above. 
*/ \ + siz_t mem_size = bli_mem_size( mem ); \ +\ + if ( mem_size < size_needed ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* The chief thread releases the existing block associated + with the mem_t, and then re-acquires a new block, saving + the associated mem_t to its passed-in mem_t. (See coment + above for why the acquisition needs to be directly to + the chief thread's passed-in mem_t and not a local + (temporary) mem_t. */ \ + bli_pba_release \ + ( \ + rntm, \ + mem \ + ); \ + bli_pba_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t + to all threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else \ + { \ + /* If the mem_t entry is already allocated and sufficiently large, + then we use it as-is. No action is needed. */ \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_mem_b ) +GENTFUNC( float, s, packm_init_mem_b ) +GENTFUNC( double, d, packm_init_mem_b ) +GENTFUNC( scomplex, c, packm_init_mem_b ) +GENTFUNC( dcomplex, z, packm_init_mem_b ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + if ( thread != NULL ) \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Check the mem_t entry provided by the caller. Only proceed if it + is allocated, which it should be. */ \ + if ( bli_mem_is_alloc( mem ) ) \ + { \ + bli_pba_release \ + ( \ + rntm, \ + mem \ + ); \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_finalize_mem_b ) +GENTFUNC( float, s, packm_finalize_mem_b ) +GENTFUNC( double, d, packm_finalize_mem_b ) +GENTFUNC( scomplex, c, packm_finalize_mem_b ) +GENTFUNC( dcomplex, z, packm_finalize_mem_b ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + dim_t* restrict k_max, \ + dim_t* restrict n_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ) \ +{ \ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + *k_max = k; \ + *n_max = ( n / nr + ( n % nr ? 1 : 0 ) ) * nr; \ +\ + /* Determine the dimensions and strides for the packed matrix B. */ \ + { \ + /* Pack B to row-stored column-panels. */ \ + *rs_p = nr; \ + *cs_p = 1; \ +\ + *pd_p = nr; \ + *ps_p = k * nr; \ +\ + /* Set the schema to "packed column panels" to indicate packing to + conventional row-stored column panels. */ \ + *schema = BLIS_PACKED_COL_PANELS; \ + } \ +\ + /* Set the buffer address provided by the caller to point to the memory + associated with the mem_t entry acquired from the memory pool. 
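The B-side layout chosen just above mirrors the A side under a transpose: B is packed into row-stored k x NR micropanels (rs_p = NR, cs_p = 1, ps_p = k * NR). A matching offset sketch (plain C, illustrative names only):

    #include <stddef.h>

    /* Offset of logical element (l, j) of B inside the packed buffer. */
    static size_t packed_b_offset( size_t l, size_t j, size_t k, size_t NR )
    {
        const size_t panel  = j / NR;      /* which NR-column micropanel    */
        const size_t j_in_p = j % NR;      /* column within that micropanel */

        return panel * ( k * NR )          /* ps_p: skip earlier panels     */
             + l     *   NR                /* rs_p: row within the panel    */
             + j_in_p;                     /* cs_p = 1                      */
    }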
*/ \ + *p = bli_mem_buffer( mem ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_b ) +GENTFUNC( float, s, packm_init_b ) +GENTFUNC( double, d, packm_init_b ) +GENTFUNC( scomplex, c, packm_init_b ) +GENTFUNC( dcomplex, z, packm_init_b ) + + +// +// Define BLAS-like interfaces to the variant chooser. +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t k_alloc, \ + dim_t n_alloc, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + pack_t schema; \ + dim_t k_max; \ + dim_t n_max; \ + dim_t pd_p; \ +\ + /* Prepare the packing destination buffer. */ \ + PASTECH2(bao_,ch,packm_init_mem_b) \ + ( \ + k_alloc, n_alloc, nr, \ + cntx, \ + rntm, \ + mem, \ + thread \ + ); \ +\ + /* Determine the packing buffer and related parameters for matrix B. */ \ + PASTECH2(bao_,ch,packm_init_b) \ + ( \ + &schema, \ + k, n, nr, \ + &k_max, &n_max, \ + p, rs_p, cs_p, \ + &pd_p, ps_p, \ + mem \ + ); \ +\ + /* Pack matrix B to the destination buffer chosen above. Here, the packed + matrix is stored to row-stored k x NR micropanels. */ \ + PASTECH2(bao_,ch,packm_var1) \ + ( \ + conj, \ + schema, \ + k, \ + n, \ + k_max, \ + n_max, \ + kappa, \ + d, incd, \ + b, rs_b, cs_b, \ + *p, *rs_p, *cs_p, \ + pd_p, *ps_p, \ + cntx, \ + thread \ + ); \ +\ + /* Barrier so that packing is done before computation. */ \ + bli_thread_barrier( thread ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_b ) +GENTFUNC( float, s, packm_b ) +GENTFUNC( double, d, packm_b ) +GENTFUNC( scomplex, c, packm_b ) +GENTFUNC( dcomplex, z, packm_b ) + diff --git a/addon/gemmd/bao_l3_packm_b.h b/addon/gemmd/bao_l3_packm_b.h new file mode 100644 index 0000000000..9161604ce9 --- /dev/null +++ b/addon/gemmd/bao_l3_packm_b.h @@ -0,0 +1,123 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_mem_b ) +GENTPROT( float, s, packm_init_mem_b ) +GENTPROT( double, d, packm_init_mem_b ) +GENTPROT( scomplex, c, packm_init_mem_b ) +GENTPROT( dcomplex, z, packm_init_mem_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_finalize_mem_b ) +GENTPROT( float, s, packm_finalize_mem_b ) +GENTPROT( double, d, packm_finalize_mem_b ) +GENTPROT( scomplex, c, packm_finalize_mem_b ) +GENTPROT( dcomplex, z, packm_finalize_mem_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + dim_t* restrict k_max, \ + dim_t* restrict n_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_b ) +GENTPROT( float, s, packm_init_b ) +GENTPROT( double, d, packm_init_b ) +GENTPROT( scomplex, c, packm_init_b ) +GENTPROT( dcomplex, z, packm_init_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t k_alloc, \ + dim_t n_alloc, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_b ) +GENTPROT( float, s, packm_b ) +GENTPROT( double, d, packm_b ) +GENTPROT( scomplex, c, packm_b ) +GENTPROT( dcomplex, z, packm_b ) + diff --git a/addon/gemmd/bao_l3_packm_var.h b/addon/gemmd/bao_l3_packm_var.h new file mode 100644 index 0000000000..063e59e5f8 --- /dev/null +++ b/addon/gemmd/bao_l3_packm_var.h @@ -0,0 +1,69 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// +// Prototype BLAS-like interfaces to the variants. +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ); + +//INSERT_GENTPROT_BASIC0( packm_var1 ) +GENTPROT( float, s, packm_var1 ) +GENTPROT( double, d, packm_var1 ) +GENTPROT( scomplex, c, packm_var1 ) +GENTPROT( dcomplex, z, packm_var1 ) + +//INSERT_GENTPROT_BASIC0( packm_var2 ) +GENTPROT( float, s, packm_var2 ) +GENTPROT( double, d, packm_var2 ) +GENTPROT( scomplex, c, packm_var2 ) +GENTPROT( dcomplex, z, packm_var2 ) diff --git a/addon/gemmd/bao_l3_packm_var1.c b/addon/gemmd/bao_l3_packm_var1.c new file mode 100644 index 0000000000..24c0a2cc13 --- /dev/null +++ b/addon/gemmd/bao_l3_packm_var1.c @@ -0,0 +1,195 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// Variant 1 provides basic support for packing by calling packm_cxk(). +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + ctype* restrict kappa_cast = kappa; \ + ctype* restrict c_cast = c; \ + ctype* restrict p_cast = p; \ +\ + dim_t iter_dim; \ + dim_t n_iter; \ + dim_t it, ic; \ + dim_t ic0; \ + doff_t ic_inc; \ + dim_t panel_len; \ + dim_t panel_len_max; \ + dim_t panel_dim; \ + dim_t panel_dim_max; \ + inc_t incc; \ + inc_t ldc; \ + inc_t ldp; \ + conj_t conjc; \ +\ +\ + /* Extract the conjugation bit from the transposition argument. */ \ + conjc = bli_extract_conj( transc ); \ +\ + /* Create flags to incidate row or column storage. Note that the + schema bit that encodes row or column is describing the form of + micro-panel, not the storage in the micro-panel. Hence the + mismatch in "row" and "column" semantics. */ \ + bool row_stored = bli_is_col_packed( schema ); \ + /*bool col_stored = bli_is_row_packed( schema );*/ \ +\ + /* If the row storage flag indicates row storage, then we are packing + to column panels; otherwise, if the strides indicate column storage, + we are packing to row panels. */ \ + if ( row_stored ) \ + { \ + /* Prepare to pack to row-stored column panels. */ \ + iter_dim = n; \ + panel_len = m; \ + panel_len_max = m_max; \ + panel_dim_max = pd_p; \ + incc = cs_c; \ + ldc = rs_c; \ + ldp = rs_p; \ + } \ + else /* if ( col_stored ) */ \ + { \ + /* Prepare to pack to column-stored row panels. */ \ + iter_dim = m; \ + panel_len = n; \ + panel_len_max = n_max; \ + panel_dim_max = pd_p; \ + incc = rs_c; \ + ldc = cs_c; \ + ldp = cs_p; \ + } \ +\ + /* Compute the total number of iterations we'll need. */ \ + n_iter = iter_dim / panel_dim_max + ( iter_dim % panel_dim_max ? 1 : 0 ); \ +\ + /* Set the initial values and increments for indices related to C and P + based on whether reverse iteration was requested. */ \ + { \ + ic0 = 0; \ + ic_inc = panel_dim_max; \ + } \ +\ + ctype* restrict p_begin = p_cast; \ +\ + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. */ \ + const dim_t nt = bli_thread_n_way( thread ); \ + const dim_t tid = bli_thread_work_id( thread ); \ +\ + /* Suppress warnings in case tid isn't used (ie: as in slab partitioning). */ \ + ( void )nt; \ + ( void )tid; \ +\ + dim_t it_start, it_end, it_inc; \ +\ + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + will depend on whether slab or round-robin partitioning was requested + at configure-time. 
*/ \ + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ +\ + /* Iterate over every logical micropanel in the source matrix. */ \ + for ( ic = ic0, it = 0; it < n_iter; \ + ic += ic_inc, it += 1 ) \ + { \ + panel_dim = bli_min( panel_dim_max, iter_dim - ic ); \ +\ + ctype* restrict c_begin = c_cast + (ic )*incc; \ +\ + ctype* restrict c_use = c_begin; \ + ctype* restrict p_use = p_begin; \ +\ + /* The definition of bli_packm_my_iter() will depend on whether slab + or round-robin partitioning was requested at configure-time. (The + default is slab.) */ \ + if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ + { \ + PASTECH2(bao_,ch,packm_cxk) \ + ( \ + conjc, \ + schema, \ + panel_dim, \ + panel_dim_max, \ + panel_len, \ + panel_len_max, \ + kappa_cast, \ + d, incd, \ + c_use, incc, ldc, \ + p_use, ldp, \ + cntx \ + ); \ + } \ +\ +/* +if ( !row_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: a packed", panel_dim_max, panel_len_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +else \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: b packed", panel_len_max, panel_dim_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +*/ \ +\ + p_begin += ps_p; \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_var1 ) +GENTFUNC( float, s, packm_var1 ) +GENTFUNC( double, d, packm_var1 ) +GENTFUNC( scomplex, c, packm_var1 ) +GENTFUNC( dcomplex, z, packm_var1 ) + diff --git a/addon/gemmd/bao_l3_packm_var2.c b/addon/gemmd/bao_l3_packm_var2.c new file mode 100644 index 0000000000..830e499b31 --- /dev/null +++ b/addon/gemmd/bao_l3_packm_var2.c @@ -0,0 +1,245 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// Variant 2 is similar to variant 1, but inlines the contents of packm_cxk(). 
+// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + ctype* restrict kappa_cast = kappa; \ + ctype* restrict c_cast = c; \ + ctype* restrict p_cast = p; \ +\ + dim_t iter_dim; \ + dim_t n_iter; \ + dim_t it, ic; \ + dim_t ic0; \ + doff_t ic_inc; \ + dim_t panel_len; \ + dim_t panel_len_max; \ + dim_t panel_dim; \ + dim_t panel_dim_max; \ + inc_t incc; \ + inc_t ldc; \ + inc_t ldp; \ + conj_t conjc; \ +\ +\ + /* Extract the conjugation bit from the transposition argument. */ \ + conjc = bli_extract_conj( transc ); \ +\ + /* Create flags to incidate row or column storage. Note that the + schema bit that encodes row or column is describing the form of + micro-panel, not the storage in the micro-panel. Hence the + mismatch in "row" and "column" semantics. */ \ + bool row_stored = bli_is_col_packed( schema ); \ + /*bool col_stored = bli_is_row_packed( schema );*/ \ +\ + /* If the row storage flag indicates row storage, then we are packing + to column panels; otherwise, if the strides indicate column storage, + we are packing to row panels. */ \ + if ( row_stored ) \ + { \ + /* Prepare to pack to row-stored column panels. */ \ + iter_dim = n; \ + panel_len = m; \ + panel_len_max = m_max; \ + panel_dim_max = pd_p; \ + incc = cs_c; \ + ldc = rs_c; \ + ldp = rs_p; \ + } \ + else /* if ( col_stored ) */ \ + { \ + /* Prepare to pack to column-stored row panels. */ \ + iter_dim = m; \ + panel_len = n; \ + panel_len_max = n_max; \ + panel_dim_max = pd_p; \ + incc = rs_c; \ + ldc = cs_c; \ + ldp = cs_p; \ + } \ +\ + /* Compute the total number of iterations we'll need. */ \ + n_iter = iter_dim / panel_dim_max + ( iter_dim % panel_dim_max ? 1 : 0 ); \ +\ + /* Set the initial values and increments for indices related to C and P + based on whether reverse iteration was requested. */ \ + { \ + ic0 = 0; \ + ic_inc = panel_dim_max; \ + } \ +\ + ctype* restrict p_begin = p_cast; \ +\ + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. */ \ + const dim_t nt = bli_thread_n_way( thread ); \ + const dim_t tid = bli_thread_work_id( thread ); \ +\ + /* Suppress warnings in case tid isn't used (ie: as in slab partitioning). */ \ + ( void )nt; \ + ( void )tid; \ +\ + dim_t it_start, it_end, it_inc; \ +\ + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + will depend on whether slab or round-robin partitioning was requested + at configure-time. */ \ + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ +\ + /* Iterate over every logical micropanel in the source matrix. */ \ + for ( ic = ic0, it = 0; it < n_iter; \ + ic += ic_inc, it += 1 ) \ + { \ + panel_dim = bli_min( panel_dim_max, iter_dim - ic ); \ +\ + ctype* restrict c_begin = c_cast + (ic )*incc; \ +\ + ctype* restrict c_use = c_begin; \ + ctype* restrict p_use = p_begin; \ +\ + /* The definition of bli_packm_my_iter() will depend on whether slab + or round-robin partitioning was requested at configure-time. (The + default is slab.) 
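The slab/round-robin distinction mentioned above only changes which pack iterations a given thread claims; the packing work itself is identical. Roughly, and as an assumption about behavior that is actually fixed at configure time inside bli_thread_range_jrir()/bli_packm_my_iter():

    #include <stdbool.h>

    /* Does thread `tid` of `nt` execute pack iteration `it`?           */

    /* Slab: each thread owns one contiguous range [it_start, it_end).  */
    static bool my_iter_slab( long it, long it_start, long it_end,
                              long tid, long nt )
    {
        (void)tid; (void)nt;
        return ( it_start <= it && it < it_end );
    }

    /* Round-robin: iterations are dealt out cyclically by thread id.   */
    static bool my_iter_rrobin( long it, long it_start, long it_end,
                                long tid, long nt )
    {
        (void)it_start; (void)it_end;
        return ( it % nt == tid );
    }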
*/ \ + if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ + { \ + /* NOTE: We assume here that kappa = 1 and therefore ignore it. If + we're wrong, this will get someone's attention. */ \ + if ( !PASTEMAC(ch,eq1)( *kappa_cast ) ) \ + bli_abort(); \ +\ + /* Perform the packing, taking conjc into account. */ \ + if ( bli_is_conj( conjc ) ) \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t d = 0; d < panel_dim; ++d ) \ + { \ + ctype* cld = c_use + (l )*ldc + (d )*incc; \ + ctype* pld = p_use + (l )*ldp + (d )*1; \ +\ + PASTEMAC(ch,copyjs)( *cld, *pld ); \ + } \ + } \ + } \ + else \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t d = 0; d < panel_dim; ++d ) \ + { \ + ctype* cld = c_use + (l )*ldc + (d )*incc; \ + ctype* pld = p_use + (l )*ldp + (d )*1; \ +\ + PASTEMAC(ch,copys)( *cld, *pld ); \ + } \ + } \ + } \ +\ + /* If panel_dim < panel_dim_max, then we zero those unused rows. */ \ + if ( panel_dim < panel_dim_max ) \ + { \ + const dim_t i = panel_dim; \ + const dim_t m_edge = panel_dim_max - panel_dim; \ + const dim_t n_edge = panel_len_max; \ + ctype* restrict p_edge = p_use + (i )*1; \ +\ + PASTEMAC(ch,set0s_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, 1, ldp \ + ); \ + } \ +\ + /* If panel_len < panel_len_max, then we zero those unused columns. */ \ + if ( panel_len < panel_len_max ) \ + { \ + const dim_t j = panel_len; \ + const dim_t m_edge = panel_dim_max; \ + const dim_t n_edge = panel_len_max - panel_len; \ + ctype* restrict p_edge = p_use + (j )*ldp; \ +\ + PASTEMAC(ch,set0s_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, 1, ldp \ + ); \ + } \ + } \ +\ +/* +if ( !row_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: a packed", panel_dim_max, panel_len_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +else \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: b packed", panel_len_max, panel_dim_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +*/ \ +\ + p_begin += ps_p; \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_var1 ) +GENTFUNC( float, s, packm_var2 ) +GENTFUNC( double, d, packm_var2 ) +GENTFUNC( scomplex, c, packm_var2 ) +GENTFUNC( dcomplex, z, packm_var2 ) + diff --git a/addon/gemmd/bao_packm_cxk.c b/addon/gemmd/bao_packm_cxk.c new file mode 100644 index 0000000000..645f09d798 --- /dev/null +++ b/addon/gemmd/bao_packm_cxk.c @@ -0,0 +1,199 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + conj_t conja, \ + pack_t schema, \ + dim_t panel_dim, \ + dim_t panel_dim_max, \ + dim_t panel_len, \ + dim_t panel_len_max, \ + ctype* kappa, \ + ctype* d, inc_t incd, \ + ctype* a, inc_t inca, inc_t lda, \ + ctype* p, inc_t ldp, \ + cntx_t* cntx \ + ) \ +{ \ + /* Note that we use panel_dim_max, not panel_dim, to query the packm + kernel function pointer. This means that we always use the same + kernel, even for edge cases. */ \ + num_t dt = PASTEMAC(ch,type); \ + l1mkr_t ker_id = panel_dim_max; \ +\ + PASTECH2(ch,opname,_ker_ft) f; \ +\ + /* Query the context for the packm kernel corresponding to the current + panel dimension, or kernel id. If the id is invalid, the function will + return NULL. */ \ + f = bli_cntx_get_packm_ker_dt( dt, ker_id, cntx ); \ +\ + /* If there exists a kernel implementation for the micro-panel dimension + provided, we invoke the implementation. Otherwise, we use scal2m. */ \ + /* NOTE: We've disabled calling packm micro-kernels from the context for + this implementation. To re-enable, change FALSE to TRUE in the + conditional below. */ \ + if ( f != NULL && FALSE ) \ + { \ + f \ + ( \ + conja, \ + schema, \ + panel_dim, \ + panel_len, \ + panel_len_max, \ + kappa, \ + a, inca, lda, \ + p, ldp, \ + cntx \ + ); \ + } \ + else \ + { \ + /* NOTE: We assume here that kappa = 1 and therefore ignore it. If + we're wrong, this will get someone's attention. */ \ + if ( !PASTEMAC(ch,eq1)( *kappa ) ) \ + bli_abort(); \ +\ + if ( d == NULL ) \ + { \ + /* Perform the packing, taking conja into account. */ \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t i = 0; i < panel_dim; ++i ) \ + { \ + ctype* ali = a + (l )*lda + (i )*inca; \ + ctype* pli = p + (l )*ldp + (i )*1; \ +\ + PASTEMAC(ch,copyjs)( *ali, *pli ); \ + } \ + } \ + } \ + else \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t i = 0; i < panel_dim; ++i ) \ + { \ + ctype* ali = a + (l )*lda + (i )*inca; \ + ctype* pli = p + (l )*ldp + (i )*1; \ +\ + PASTEMAC(ch,copys)( *ali, *pli ); \ + } \ + } \ + } \ + } \ + else /* if ( d != NULL ) */ \ + { \ + /* Perform the packing, taking conja into account. */ \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t i = 0; i < panel_dim; ++i ) \ + { \ + ctype* ali = a + (l )*lda + (i )*inca; \ + ctype* dl = d + (l )*incd; \ + ctype* pli = p + (l )*ldp + (i )*1; \ +\ + /* Note that ali must be the second operand here since + that is what is conjugated by scal2js. 
*/ \ + PASTEMAC(ch,scal2js)( *dl, *ali, *pli ); \ + } \ + } \ + } \ + else \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t i = 0; i < panel_dim; ++i ) \ + { \ + ctype* ali = a + (l )*lda + (i )*inca; \ + ctype* dl = d + (l )*incd; \ + ctype* pli = p + (l )*ldp + (i )*1; \ +\ + PASTEMAC(ch,scal2s)( *ali, *dl, *pli ); \ + } \ + } \ + } \ + } \ +\ + /* If panel_dim < panel_dim_max, then we zero those unused rows. */ \ + if ( panel_dim < panel_dim_max ) \ + { \ + const dim_t i = panel_dim; \ + const dim_t m_edge = panel_dim_max - panel_dim; \ + const dim_t n_edge = panel_len_max; \ + ctype* restrict p_edge = p + (i )*1; \ +\ + PASTEMAC(ch,set0s_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, 1, ldp \ + ); \ + } \ +\ + /* If panel_len < panel_len_max, then we zero those unused columns. */ \ + if ( panel_len < panel_len_max ) \ + { \ + const dim_t j = panel_len; \ + const dim_t m_edge = panel_dim_max; \ + const dim_t n_edge = panel_len_max - panel_len; \ + ctype* restrict p_edge = p + (j )*ldp; \ +\ + PASTEMAC(ch,set0s_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, 1, ldp \ + ); \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_cxk ) +GENTFUNC( float, s, packm_cxk ) +GENTFUNC( double, d, packm_cxk ) +GENTFUNC( scomplex, c, packm_cxk ) +GENTFUNC( dcomplex, z, packm_cxk ) + diff --git a/addon/gemmd/bao_packm_cxk.h b/addon/gemmd/bao_packm_cxk.h new file mode 100644 index 0000000000..3e977a7cc2 --- /dev/null +++ b/addon/gemmd/bao_packm_cxk.h @@ -0,0 +1,59 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
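For reference, here is a minimal, double-precision-only sketch (not the macro-generated kernel above) of what the scalar fallback in bao_packm_cxk does when d != NULL: column l of the micro-panel is scaled by d[l*incd] as it is copied into the packed buffer, and any short edges are zero-filled out to panel_dim_max x panel_len_max. The helper name pack_d_scaled and the toy dimensions are illustrative only.

#include <stdio.h>

/* Pack a panel_dim x panel_len micro-panel of A into P (leading dimension
   ldp), scaling column l by d[l*incd], then zero-pad up to
   panel_dim_max x panel_len_max. */
static void pack_d_scaled
     (
       long panel_dim, long panel_dim_max,
       long panel_len, long panel_len_max,
       const double* d, long incd,
       const double* a, long inca, long lda,
       double*       p, long ldp
     )
{
    for ( long l = 0; l < panel_len; ++l )
        for ( long i = 0; i < panel_dim; ++i )
            p[ l*ldp + i ] = d[ l*incd ] * a[ l*lda + i*inca ];

    /* Zero unused rows ... */
    for ( long l = 0; l < panel_len_max; ++l )
        for ( long i = panel_dim; i < panel_dim_max; ++i )
            p[ l*ldp + i ] = 0.0;

    /* ... and unused columns. */
    for ( long l = panel_len; l < panel_len_max; ++l )
        for ( long i = 0; i < panel_dim; ++i )
            p[ l*ldp + i ] = 0.0;
}

int main( void )
{
    double a[6] = { 1, 2, 3, 4, 5, 6 };  /* 2x3, column-major (lda = 2) */
    double d[3] = { 10, 100, 1000 };
    double p[12];                        /* 4x3 packed panel (ldp = 4)  */

    pack_d_scaled( 2, 4, 3, 3, d, 1, a, 1, 2, p, 4 );

    for ( long i = 0; i < 4; ++i )
    {
        for ( long l = 0; l < 3; ++l )
            printf( "%8.1f", p[ l*4 + i ] );
        printf( "\n" );
    }
    return 0;
}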
+ +*/ + + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + conj_t conja, \ + pack_t schema, \ + dim_t panel_dim, \ + dim_t panel_dim_max, \ + dim_t panel_len, \ + dim_t panel_len_max, \ + ctype* kappa, \ + ctype* d, inc_t incd, \ + ctype* a, inc_t inca, inc_t lda, \ + ctype* p, inc_t ldp, \ + cntx_t* cntx \ + ); + +//INSERT_GENTPROT_BASIC0( packm_cxk ) +GENTPROT( float, s, packm_cxk ) +GENTPROT( double, d, packm_cxk ) +GENTPROT( scomplex, c, packm_cxk ) +GENTPROT( dcomplex, z, packm_cxk ) + diff --git a/addon/gemmd/gemmd.h b/addon/gemmd/gemmd.h new file mode 100644 index 0000000000..cab61bd181 --- /dev/null +++ b/addon/gemmd/gemmd.h @@ -0,0 +1,54 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of copyright holder(s) nor the names + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef GEMMD_H +#define GEMMD_H + +// This header should contain (or #include) any definitions that must be +// folded into blis.h. + +#include "bao_gemmd.h" +#include "bao_gemmd_check.h" +#include "bao_gemmd_var.h" + +#include "bao_l3_packm_a.h" +#include "bao_l3_packm_b.h" +#include "bao_l3_packm_var.h" + +#include "bao_packm_cxk.h" + +#include "bao_l3_decor.h" + + +#endif diff --git a/addon/gemmd/thread/bao_l3_decor.h b/addon/gemmd/thread/bao_l3_decor.h new file mode 100644 index 0000000000..b4fd2b9b76 --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor.h @@ -0,0 +1,75 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
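The operand lists threaded through these headers (alpha, a, d, b, beta, c) suggest that the addon computes a diagonally scaled product, roughly C := beta*C + alpha*A*diag(d)*B. The routine below is only an assumption-laden reference for that semantic, written as plain column-major C; it is not the addon's API, and ref_gemmd is a made-up name.

#include <stdio.h>

/* Naive reference: C := beta*C + alpha * A * diag(d) * B, with A m x k,
   B k x n, C m x n, all column-major. Semantic reference only. */
static void ref_gemmd
     (
       long m, long n, long k,
       double alpha, const double* a, long lda,
                     const double* d,
                     const double* b, long ldb,
       double beta,        double* c, long ldc
     )
{
    for ( long j = 0; j < n; ++j )
        for ( long i = 0; i < m; ++i )
        {
            double acc = 0.0;
            for ( long l = 0; l < k; ++l )
                acc += a[ i + l*lda ] * d[ l ] * b[ l + j*ldb ];
            c[ i + j*ldc ] = beta * c[ i + j*ldc ] + alpha * acc;
        }
}

int main( void )
{
    double a[4] = { 1, 2, 3, 4 };   /* 2x2, column-major */
    double b[4] = { 5, 6, 7, 8 };   /* 2x2, column-major */
    double d[2] = { 1, -1 };
    double c[4] = { 0, 0, 0, 0 };

    ref_gemmd( 2, 2, 2, 1.0, a, 2, d, b, 2, 0.0, c, 2 );
    printf( "%g %g\n%g %g\n", c[0], c[2], c[1], c[3] );
    return 0;
}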
+ - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_H +#define BLIS_SBX_L3_DECOR_H + +// -- sup definitions ---------------------------------------------------------- + +// Level-3 sup internal function type. +typedef void (*l3sbxint_t) + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +// Level-3 sup thread decorator prototype. +void bao_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ); + +// Include definitions specific to the method of multithreading. +#include "bao_l3_decor_single.h" +#include "bao_l3_decor_openmp.h" +#include "bao_l3_decor_pthreads.h" + +#endif + diff --git a/addon/gemmd/thread/bao_l3_decor_openmp.c b/addon/gemmd/thread/bao_l3_decor_openmp.c new file mode 100644 index 0000000000..1aca8de275 --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor_openmp.c @@ -0,0 +1,140 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifdef BLIS_ENABLE_OPENMP + +// Define a dummy thread entry function, which is needed in the pthreads +// version, so that when building Windows DLLs (with OpenMP enabled or with +// no multithreading) we don't risk having an unresolved symbol. +void* bao_l3_thread_entry( void* data_void ) { return NULL; } + +//#define PRINT_THRINFO + +void bao_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // Query the total number of threads from the rntm_t object. + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. We do + // this up-front only so that we have the rntm_t.sba_pool field + // initialized and ready for the global communicator creation below. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. This will be + // inherited by all of the child threads when they make local copies of + // the rntm below. + bli_pba_rntm_set_pba( rntm ); + + // Allcoate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); + + + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Create a thread-local copy of the master thread's rntm_t. This is + // necessary since we want each thread to be able to track its own + // small block pool_t as it executes down the function stack. + rntm_t rntm_l = *rntm; + rntm_t* restrict rntm_p = &rntm_l; + + // Query the thread's id from OpenMP. + const dim_t tid = omp_get_thread_num(); + + // Check for a somewhat obscure OpenMP thread-mistmatch issue. + // NOTE: This calls the same function used for the conventional/large + // code path. + bli_l3_thread_decorator_thread_check( n_threads, tid, gl_comm, rntm_p ); + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + bli_sba_rntm_set_pool( tid, array, rntm_p ); + + thrinfo_t* thread = NULL; + + // Create the root node of the thread's thrinfo_t structure. + bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread ); + + func + ( + alpha, + a, + d, + b, + beta, + c, + cntx, + rntm_p, + thread + ); + + // Free the current thread's thrinfo_t structure. 
+ bli_l3_sup_thrinfo_free( rntm_p, thread ); + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called from the thread entry function). + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); +} + +#endif + diff --git a/addon/gemmd/thread/bao_l3_decor_openmp.h b/addon/gemmd/thread/bao_l3_decor_openmp.h new file mode 100644 index 0000000000..9c956d7c36 --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor_openmp.h @@ -0,0 +1,44 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_OPENMP_H +#define BLIS_SBX_L3_DECOR_OPENMP_H + +// Definitions specific to situations when OpenMP multithreading is enabled. +#ifdef BLIS_ENABLE_OPENMP + +#endif + +#endif + diff --git a/addon/gemmd/thread/bao_l3_decor_pthreads.c b/addon/gemmd/thread/bao_l3_decor_pthreads.c new file mode 100644 index 0000000000..587b8400f1 --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor_pthreads.c @@ -0,0 +1,220 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
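As a minimal standalone illustration of the fork pattern used by the OpenMP decorator above: a parallel region of the requested width in which every thread takes a private copy of the launch-time runtime state and queries its id before doing its share of work. The rntm_like_t struct is a stand-in, not BLIS's rntm_t; build with -fopenmp.

#include <omp.h>
#include <stdio.h>

/* Stand-in for the runtime state each thread copies locally (the real code
   copies a rntm_t so each thread can track its own small-block pool). */
typedef struct { int num_threads; void* sba_pool; } rntm_like_t;

int main( void )
{
    rntm_like_t rntm = { 4, NULL };
    const int n_threads = rntm.num_threads;

    #pragma omp parallel num_threads(n_threads)
    {
        /* Thread-local copy, mirroring 'rntm_t rntm_l = *rntm;' above. */
        rntm_like_t rntm_l = rntm;
        const int   tid    = omp_get_thread_num();

        /* Here the real decorator binds tid's pool into the local copy,
           creates a thrinfo_t root, and calls the supplied variant. */
        #pragma omp critical
        printf( "thread %d of %d ready (local copy at %p)\n",
                tid, rntm_l.num_threads, ( void* )&rntm_l );
    }
    return 0;
}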
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifdef BLIS_ENABLE_PTHREADS + +// A data structure to assist in passing operands to additional threads. +typedef struct thread_data +{ + l3sbxint_t func; + opid_t family; + obj_t* alpha; + obj_t* a; + obj_t* d; + obj_t* b; + obj_t* beta; + obj_t* c; + cntx_t* cntx; + rntm_t* rntm; + dim_t tid; + thrcomm_t* gl_comm; + array_t* array; +} thread_data_t; + +// Entry point function for additional threads. +void* bao_l3_thread_entry( void* data_void ) +{ + thread_data_t* data = data_void; + + l3sbxint_t func = data->func; + opid_t family = data->family; + obj_t* alpha = data->alpha; + obj_t* a = data->a; + obj_t* d = data->d; + obj_t* b = data->b; + obj_t* beta = data->beta; + obj_t* c = data->c; + cntx_t* cntx = data->cntx; + rntm_t* rntm = data->rntm; + dim_t tid = data->tid; + array_t* array = data->array; + thrcomm_t* gl_comm = data->gl_comm; + + ( void )family; + + // Create a thread-local copy of the master thread's rntm_t. This is + // necessary since we want each thread to be able to track its own + // small block pool_t as it executes down the function stack. + rntm_t rntm_l = *rntm; + rntm_t* restrict rntm_p = &rntm_l; + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + bli_sba_rntm_set_pool( tid, array, rntm_p ); + + thrinfo_t* thread = NULL; + + // Create the root node of the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread ); + + func + ( + alpha, + a, + d, + b, + beta, + c, + cntx, + rntm_p, + thread + ); + + // Free the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_free( rntm_p, thread ); + + return NULL; +} + +void bao_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + err_t r_val; + + // Query the total number of threads from the context. + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. 
+ array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. We do + // this up-front only so that we have the rntm_t.sba_pool field + // initialized and ready for the global communicator creation below. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. This will be + // inherited by all of the child threads when they make local copies of + // the rntm below. + bli_pba_rntm_set_pba( rntm ); + + // Allocate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); + + // Allocate an array of pthread objects and auxiliary data structs to pass + // to the thread entry functions. + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_pthread_t* pthreads = bli_malloc_intl( sizeof( bli_pthread_t ) * n_threads, &r_val ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + thread_data_t* datas = bli_malloc_intl( sizeof( thread_data_t ) * n_threads, &r_val ); + + // NOTE: We must iterate backwards so that the chief thread (thread id 0) + // can spawn all other threads before proceeding with its own computation. + for ( dim_t tid = n_threads - 1; 0 <= tid; tid-- ) + { + // Set up thread data for additional threads (beyond thread 0). + datas[tid].func = func; + datas[tid].family = family; + datas[tid].alpha = alpha; + datas[tid].a = a; + datas[tid].d = d; + datas[tid].b = b; + datas[tid].beta = beta; + datas[tid].c = c; + datas[tid].cntx = cntx; + datas[tid].rntm = rntm; + datas[tid].tid = tid; + datas[tid].gl_comm = gl_comm; + datas[tid].array = array; + + // Spawn additional threads for ids greater than 1. + if ( tid != 0 ) + bli_pthread_create( &pthreads[tid], NULL, &bao_l3_thread_entry, &datas[tid] ); + else + bao_l3_thread_entry( ( void* )(&datas[0]) ); + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called from the thread entry function). + + // Thread 0 waits for additional threads to finish. + for ( dim_t tid = 1; tid < n_threads; tid++ ) + { + bli_pthread_join( pthreads[tid], NULL ); + } + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_free_intl( pthreads ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_free_intl( datas ); +} + +#endif + diff --git a/addon/gemmd/thread/bao_l3_decor_pthreads.h b/addon/gemmd/thread/bao_l3_decor_pthreads.h new file mode 100644 index 0000000000..69adec45ee --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor_pthreads.h @@ -0,0 +1,47 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
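The same launch shape, sketched with raw POSIX threads instead of the bli_pthread_* wrappers: per-id data is filled in, ids are visited in reverse so the chief thread (tid 0) spawns everyone else before running its own entry call, and only ids 1..n-1 are joined. Names and the work done in entry() are illustrative; link with -lpthread.

#include <pthread.h>
#include <stdio.h>

typedef struct { long tid; long nt; } tdata_t;

static void* entry( void* p )
{
    tdata_t* d = p;
    printf( "tid %ld / %ld doing its slice of work\n", d->tid, d->nt );
    return NULL;
}

int main( void )
{
    enum { NT = 4 };
    pthread_t th[NT];
    tdata_t   datas[NT];

    /* Iterate backwards so tid 0 (the chief) spawns all others first and
       then runs its own entry call directly, as in the decorator above. */
    for ( long tid = NT - 1; 0 <= tid; tid-- )
    {
        datas[tid].tid = tid;
        datas[tid].nt  = NT;

        if ( tid != 0 )
            pthread_create( &th[tid], NULL, entry, &datas[tid] );
        else
            entry( &datas[0] );
    }

    /* The chief joins the additional threads (ids 1..NT-1). */
    for ( long tid = 1; tid < NT; tid++ )
        pthread_join( th[tid], NULL );

    return 0;
}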
+ - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_PTHREADS_H +#define BLIS_SBX_L3_DECOR_PTHREADS_H + +// Definitions specific to situations when POSIX multithreading is enabled. +#ifdef BLIS_ENABLE_PTHREADS + +// Thread entry point prototype. +void* bao_l3_thread_entry( void* data_void ); + +#endif + +#endif + diff --git a/addon/gemmd/thread/bao_l3_decor_single.c b/addon/gemmd/thread/bao_l3_decor_single.c new file mode 100644 index 0000000000..d60891d65b --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor_single.c @@ -0,0 +1,143 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" + +#ifndef BLIS_ENABLE_MULTITHREADING + +#define SKIP_THRINFO_TREE + +void bao_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + //pack_t schema_a, + //pack_t schema_b, + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // For sequential execution, we use only one thread. + const dim_t n_threads = 1; + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. + bli_pba_rntm_set_pba( rntm ); + +#ifndef SKIP_THRINFO_TREE + // Allcoate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); +#endif + + + { + // NOTE: We don't need to create another copy of the rntm_t since + // it was already copied in one of the high-level oapi functions. + rntm_t* restrict rntm_p = rntm; + + // There is only one thread id (for the thief thread). + const dim_t tid = 0; + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + // NOTE: This is commented out because, in the single-threaded case, + // this is redundant since it's already been done above. + //bli_sba_rntm_set_pool( tid, array, rntm_p ); + +#ifndef SKIP_THRINFO_TREE + thrinfo_t* thread = NULL; + + // Create the root node of the thread's thrinfo_t structure. + bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread ); +#else + // This optimization allows us to use one of the global thrinfo_t + // objects for single-threaded execution rather than grow one from + // scratch. The key is that bli_thrinfo_sup_grow(), which is called + // from within the variants, will immediately return if it detects + // that the thrinfo_t* passed into it is either + // &BLIS_GEMM_SINGLE_THREADED or &BLIS_PACKM_SINGLE_THREADED. + thrinfo_t* thread = &BLIS_GEMM_SINGLE_THREADED; + + ( void )tid; +#endif + + func + ( + alpha, + a, + d, + b, + beta, + c, + cntx, + rntm_p, + thread + ); + +#ifndef SKIP_THRINFO_TREE + // Free the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_free( rntm_p, thread ); +#endif + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called above). + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); +} + +#endif + diff --git a/addon/gemmd/thread/bao_l3_decor_single.h b/addon/gemmd/thread/bao_l3_decor_single.h new file mode 100644 index 0000000000..211a43a894 --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor_single.h @@ -0,0 +1,44 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
+ + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_SINGLE_H +#define BLIS_SBX_L3_DECOR_SINGLE_H + +// Definitions specific to situations when multithreading is disabled. +#ifndef BLIS_ENABLE_MULTITHREADING + +#endif + +#endif + diff --git a/bench/CMakeLists.txt b/bench/CMakeLists.txt new file mode 100644 index 0000000000..00d01fdd21 --- /dev/null +++ b/bench/CMakeLists.txt @@ -0,0 +1,97 @@ +##Copyright (C) 2022, Advanced Micro Devices, Inc. 
All rights reserved.## + +add_definitions(-DBLAS="AOCL") +add_definitions(-DN_REPEAT=1000) +add_definitions(-DINT_FS="%lld") +add_definitions(-DUINT_FS="%llu") + +add_executable(BenchAmaxv bench_amaxv.c) +target_link_libraries(BenchAmaxv debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchAmaxv OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchAmaxv optimized "${LIB_NAME}.lib") + +add_executable(BenchAxpbyv bench_axpbyv.c) +target_link_libraries(BenchAxpbyv debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchAxpbyv OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchAxpbyv optimized "${LIB_NAME}.lib") + +add_executable(BenchCopyv bench_copyv.c) +target_link_libraries(BenchCopyv debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchCopyv OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchCopyv optimized "${LIB_NAME}.lib") + +add_executable(BenchDotv bench_dotv.c) +target_link_libraries(BenchDotv debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchDotv OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchDotv optimized "${LIB_NAME}.lib") + +add_executable(BenchGemm bench_gemm.c) +target_link_libraries(BenchGemm debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchGemm OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchGemm optimized "${LIB_NAME}.lib") + +add_executable(BenchGemmt bench_gemmt.c) +target_link_libraries(BenchGemmt debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchGemmt OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchGemmt optimized "${LIB_NAME}.lib") + +add_executable(BenchGemv bench_gemv.c) +target_link_libraries(BenchGemv debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchGemv OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchGemv optimized "${LIB_NAME}.lib") + +add_executable(BenchGer bench_ger.c) +target_link_libraries(BenchGer debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchGer OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchGer optimized "${LIB_NAME}.lib") + +add_executable(BenchScalv bench_scalv.c) +target_link_libraries(BenchScalv debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchScalv OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchScalv optimized "${LIB_NAME}.lib") + +add_executable(BenchSwapv bench_swapv.c) +target_link_libraries(BenchSwapv debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchSwapv OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchSwapv optimized "${LIB_NAME}.lib") + +add_executable(BenchSyrk bench_syrk.c) +target_link_libraries(BenchSyrk debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchSyrk OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchSyrk optimized "${LIB_NAME}.lib") + +add_executable(BenchTrsm bench_trsm.c) +target_link_libraries(BenchTrsm debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchTrsm OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchTrsm optimized "${LIB_NAME}.lib") + +add_executable(BenchTrsv bench_trsv.c) +target_link_libraries(BenchTrsv debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchTrsv OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchTrsv optimized "${LIB_NAME}.lib") diff --git a/bench/Makefile b/bench/Makefile index d47485b2fc..0203d5a5b0 100755 --- a/bench/Makefile +++ b/bench/Makefile @@ -6,7 +6,7 @@ # libraries. 
# # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2017 - 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2017 - 2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -186,6 +186,7 @@ blis: \ bench_gemv_blis.x \ bench_syrk_blis.x \ bench_ger_blis.x \ + bench_nrm2_blis.x \ bench_scalv_blis.x \ bench_dotv_blis.x \ bench_trsv_blis.x \ @@ -201,6 +202,7 @@ openblas: \ bench_gemv_openblas.x \ bench_syrk_openblas.x \ bench_ger_openblas.x \ + bench_nrm2_openblas.x \ bench_scalv_openblas.x \ bench_dotv_openblas.x \ bench_trsv_openblas.x \ @@ -231,6 +233,7 @@ mkl: \ bench_gemv_mkl.x \ bench_syrk_mkl.x \ bench_ger_mkl.x \ + bench_nrm2_mkl.x \ bench_scalv_mkl.x \ bench_dotv_mkl.x \ bench_trsv_mkl.x \ @@ -246,17 +249,17 @@ $(TEST_OBJ_PATH)/%.o: $(TEST_SRC_PATH)/%.c $(CC) $(CFLAGS) -c $< -o $@ bench_%_openblas.o: bench_%.c - $(CC) $(CFLAGS) -DBLAS=\"openblas\" $(NRTS) -c $< -o $@ + $(CC) $(CFLAGS) -DBLAS=\"openblas\" $(NRTS) -DINT_FS=\"%ld\" -DUINT_FS=\"%lu\" -c $< -o $@ bench_%_atlas.o: bench_%.c - $(CC) $(CFLAGS) -DBLAS=\"atlas\" $(NRTS) -c $< -o $@ + $(CC) $(CFLAGS) -DBLAS=\"atlas\" $(NRTS) -DINT_FS=\"%ld\" -DUINT_FS=\"%lu\" -c $< -o $@ bench_%_mkl.o: bench_%.c - $(CC) $(CFLAGS) -DBLAS=\"mkl\" $(NRTS) -c $< -o $@ + $(CC) $(CFLAGS) -DBLAS=\"mkl\" $(NRTS) -DINT_FS=\"%ld\" -DUINT_FS=\"%lu\" -c $< -o $@ bench_%_blis.o: bench_%.c - $(CC) $(CFLAGS) -DBLAS=\"aocl\" $(NRTS) -c $< -o $@ + $(CC) $(CFLAGS) -DBLAS=\"aocl\" $(NRTS) -DINT_FS=\"%ld\" -DUINT_FS=\"%lu\" -c $< -o $@ # -- Executable file rules -- diff --git a/bench/bench_amaxv.c b/bench/bench_amaxv.c index 739bd0f979..2a0e578975 100644 --- a/bench/bench_amaxv.c +++ b/bench/bench_amaxv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -101,8 +101,7 @@ int main( int argc, char** argv ) char tmp[256]; // to store function name, line no present in logs. // {S,D,C,Z} {n incx} - - while (fscanf(fin, "%s %c %ld %ld \n", + while (fscanf(fin, "%s %c " INT_FS INT_FS " \n", tmp, &dt_ch, &n, &incx) == 4) { diff --git a/bench/bench_aocl_gemm/Makefile b/bench/bench_aocl_gemm/Makefile new file mode 100755 index 0000000000..91b3a7b587 --- /dev/null +++ b/bench/bench_aocl_gemm/Makefile @@ -0,0 +1,132 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. 
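The -DINT_FS/-DUINT_FS definitions added above rely on C's adjacent-string-literal concatenation so that one fscanf format line works whichever integer width the build selects ("%lld" in the CMake/Windows path, "%ld" in the Makefile path). A self-contained sketch of that trick follows; the macro value, the input line, and all names here are chosen for the demo.

#include <stdio.h>

/* Defined locally so the example is self-contained; in the benches this
   macro is injected by the build system instead. */
#define INT_FS "%ld"

int main( void )
{
    const char* line = "bli_damaxv d 1000 1";   /* made-up log line */
    char tmp[256];
    char dt_ch;
    long n, incx;

    /* Adjacent string literals concatenate at compile time, so the same
       source line scans correctly for either integer width. */
    if ( sscanf( line, "%255s %c " INT_FS " " INT_FS,
                 tmp, &dt_ch, &n, &incx ) == 4 )
        printf( "%s: dt=%c n=%ld incx=%ld\n", tmp, dt_ch, n, incx );

    return 0;
}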
+# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +# Makefile for lpgemm bench. +# + +# +# --- Makefile PHONY target definitions ---------------------------------------- +# + +.PHONY: all \ + blis \ + check-env check-env-mk check-lib \ + clean cleanx + +# Comments: +# - DIST_PATH is assumed to not exist if BLIS_INSTALL_PATH is given. +# - We must use recursively expanded assignment for LIB_PATH and INC_PATH in +# the second case because CONFIG_NAME is not yet set. +ifneq ($(strip $(BLIS_INSTALL_PATH)),) +LIB_PATH := $(BLIS_INSTALL_PATH)/lib +INC_PATH := $(BLIS_INSTALL_PATH)/include/blis +SHARE_PATH := $(BLIS_INSTALL_PATH)/share/blis +else +DIST_PATH := ../.. +LIB_PATH = ../../lib/$(CONFIG_NAME) +INC_PATH = ../../include/$(CONFIG_NAME) +SHARE_PATH := ../.. +endif + + + +# +# --- Include common makefile definitions -------------------------------------- +# + +# Include the common makefile fragment. +-include $(SHARE_PATH)/common.mk + +# +# --- General build definitions ------------------------------------------------ +# + +TEST_SRC_PATH := . +TEST_OBJ_PATH := . + +# Gather all local object files. +TEST_OBJS := $(patsubst $(TEST_SRC_PATH)/%.c, \ + $(TEST_OBJ_PATH)/%.o, \ + $(wildcard $(TEST_SRC_PATH)/*.c)) + + + +# Override the value of CINCFLAGS so that the value of CFLAGS returned by +# get-user-cflags-for() is not cluttered up with include paths needed only +# while building BLIS. +CINCFLAGS := -I$(INC_PATH) -I$(CBLAS_HEADER_PATH) + +# Use the CFLAGS for the configuration family. +CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) + +# Add local header paths to CFLAGS +CFLAGS += -I$(TEST_SRC_PATH) + +# Locate the libblis library to which we will link. 
+#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) + +# +# --- Targets/rules ------------------------------------------------------------ +# + +# Complete list of possible targets when defining 'all': +# +# blis openblas atlas mkl mac essl +# +all: blis + +blis: \ + bench_lpgemm_blis.x + + +# --Object file rules -- + +$(TEST_OBJ_PATH)/%.o: $(TEST_SRC_PATH)/%.c + $(CC) $(CFLAGS) -c $< -o $@ + +bench_%_blis.o: bench_%.c + $(CC) $(CFLAGS) -DBLAS=\"aocl\" $(NRTS) -DINT_FS=\"%ld\" -DUINT_FS=\"%lu\" -c $< -o $@ + + +# -- Executable file rules -- + +bench_%_blis.x: bench_%_blis.o $(LIBBLIS_LINK) + $(LINKER) $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@ + + +# -- Clean rules -- + +clean: cleanx + +cleanx: + - $(RM_F) *.o *.x diff --git a/bench/bench_aocl_gemm/bench_input.txt b/bench/bench_aocl_gemm/bench_input.txt new file mode 100644 index 0000000000..d8b8226a13 --- /dev/null +++ b/bench/bench_aocl_gemm/bench_input.txt @@ -0,0 +1,783 @@ +b r r 480 20 2050 2050 20 20 +b r r 481 20 2050 2050 20 20 +b r r 482 20 2050 2050 20 20 +b r p 483 20 2050 2050 20 20 +b r R 484 20 2050 2050 20 20 +b r R 485 20 2050 2050 20 20 +b r R 480 39 2050 2050 39 39 +b r R 481 39 2050 2050 39 39 +b r R 482 39 2050 2050 39 39 +b r R 483 39 2050 2050 39 39 +b r R 484 39 2050 2050 39 39 +b r p 485 39 2050 2050 39 39 +b r p 480 50 2050 2050 50 50 +b r p 481 50 2050 2050 50 50 +b r p 482 50 2050 2050 50 50 +b r p 483 50 2050 2050 50 50 +b r p 484 50 2050 2050 50 50 +b r p 485 50 2050 2050 50 50 +b r R 480 1108 2050 2050 1108 1108 +b r R 481 1108 2050 2050 1108 1108 +b r R 482 1108 2050 2050 1108 1108 +b r R 483 1108 2050 2050 1108 1108 +b r R 484 1108 2050 2050 1108 1108 +b r R 485 1108 2050 2050 1108 1108 +b r R 480 1127 2050 2050 1127 1127 +b r R 481 1127 2050 2050 1127 1127 +b r R 482 1127 2050 2050 1127 1127 +b r R 483 1127 2050 2050 1127 1127 +b r p 484 1127 2050 2050 1127 1127 +b r p 485 1127 2050 2050 1127 1127 +b r p 480 1138 2050 2050 1138 1138 +b r p 481 1138 2050 2050 1138 1138 +b r p 482 1138 2050 2050 1138 1138 +b r p 483 1138 2050 2050 1138 1138 +b r p 484 1138 2050 2050 1138 1138 +b r p 485 1138 2050 2050 1138 1138 +b r p 1 1 3 3 1 1 +b r p 1 9 3 3 9 9 +b r p 1 2048 3 3 2048 2048 +b r p 1 2048 5192 5192 2048 2048 +b r p 9 1 3 3 1 1 +b r p 576 1 3500 3500 1 1 +b r p 1 1 1 1 1 1 +b r p 102 1088 1024 1024 1088 1088 +b r p 102 2048 1024 1024 2048 2048 +b r p 485 656 1024 1024 656 656 +b r p 483 656 1024 1024 656 656 +b r p 81 128 3 3 128 128 +b r p 1022 512 515 515 512 512 +b r p 74 512 515 515 512 512 +b r p 253 2048 515 515 2048 2048 +b r p 8192 1040 515 515 1040 1040 +b r p 10 1029 515 515 1029 1029 +b r p 24 1040 2050 2050 1040 1040 +b r p 1024 1029 2050 2050 1029 1029 +b r p 480 660 2050 2050 660 660 +b r p 481 660 2050 2050 660 660 +b r p 482 660 2050 2050 660 660 +b r p 483 660 2050 2050 660 660 +b r p 484 660 2050 2050 660 660 +b r p 485 660 2050 2050 660 660 +b r p 480 679 2050 2050 679 679 +b r p 481 679 2050 2050 679 679 +b r p 482 679 2050 2050 679 679 +b r p 483 679 2050 2050 679 679 +b r p 484 679 2050 2050 679 679 +b r p 485 679 2050 2050 679 679 +b r p 480 690 2050 2050 690 690 +b r p 481 690 2050 2050 690 690 +b r p 482 690 2050 2050 690 690 +b r p 483 690 2050 2050 690 690 +b r p 484 690 2050 2050 690 690 +b r p 485 690 2050 2050 690 690 +b r p 480 660 2048 2048 660 660 +b r p 481 660 2048 2048 660 660 +b r p 482 660 2048 2048 660 660 +b r p 483 660 2048 2048 660 660 +b r p 484 660 2048 2048 660 660 +b r p 485 660 2048 2048 660 660 +b r p 480 679 2048 2048 679 679 +b r p 481 679 2048 2048 679 679 +b r p 482 679 
2048 2048 679 679 +b r p 483 679 2048 2048 679 679 +b r p 484 679 2048 2048 679 679 +b r p 485 679 2048 2048 679 679 +b r p 480 690 2048 2048 690 690 +b r p 481 690 2048 2048 690 690 +b r p 482 690 2048 2048 690 690 +b r p 483 690 2048 2048 690 690 +b r p 484 690 2048 2048 690 690 +b r p 485 690 2048 2048 690 690 +b r p 480 656 1024 1024 656 656 +b r p 480 128 3 3 128 128 +b r p 1024 512 515 515 512 512 +b r p 1024 2048 1024 1024 2048 2048 +b r p 1024 2048 515 515 2048 2048 +b r p 1024 1040 515 515 1040 1040 +b r p 5 1029 515 515 1029 1029 +b r p 1024 1029 515 515 1029 1029 +b r p 1024 1040 2050 2050 1040 1040 +b r p 1029 1029 2050 2050 1029 1029 +b r R 480 646 2050 2050 646 646 +b r R 481 646 2050 2050 646 646 +b r R 482 646 2050 2050 646 646 +b r R 483 646 2050 2050 646 646 +b r R 484 646 2050 2050 646 646 +b r R 485 646 2050 2050 646 646 +b r R 481 656 2050 2050 656 656 +b r R 482 656 2050 2050 656 656 +b r R 483 656 2050 2050 656 656 +b r R 484 656 2050 2050 656 656 +b r p 485 656 2050 2050 656 656 +b r p 480 672 2050 2050 672 672 +b r p 481 672 2050 2050 672 672 +b r p 482 672 2050 2050 672 672 +b r p 483 672 2050 2050 672 672 +b r p 484 672 2050 2050 672 672 +b r p 485 672 2050 2050 672 672 +b r p 480 688 2050 2050 688 688 +b r p 481 688 2050 2050 688 688 +b r r 482 688 2050 2050 688 688 +b r r 483 688 2050 2050 688 688 +b r r 484 688 2050 2050 688 688 +b r r 485 688 2050 2050 688 688 +b r r 1024 512 64 64 512 512 +b r r 16 256 512 512 256 256 +b r r 480 640 512 512 640 640 +b r r 64 768 512 512 768 768 +b r r 128 128 128 128 128 128 +b r r 1024 64 512 512 64 64 +b r r 1024 256 32 32 256 256 +b r r 1024 512 64 64 512 512 +b r r 480 640 512 512 640 640 +b r p 1024 32 256 256 32 32 +b r P 1024 64 512 512 64 64 +b r P 64 800 320 320 800 800 +b r P 64 768 512 512 768 768 +b r P 16 256 512 512 256 256 +b r P 128 128 128 128 128 128 +b r P 256 512 256 256 512 512 +b r P 1024 1024 1024 1024 1024 1024 +b r P 480 640 1024 1024 640 640 +b r P 480 640 256 256 640 640 +b r P 8 64 32 32 64 64 +b r P 9 64 32 32 64 64 +b r P 10 128 64 64 128 128 +b r P 8 8 8 8 8 8 +b r P 12 12 12 12 12 12 +b r P 25 25 25 25 25 25 +b r P 25 25 20 20 25 25 +b c p 485 39 2050 485 2050 485 +b c p 480 50 2050 480 2050 480 +b c p 481 50 2050 481 2050 481 +b c p 482 50 2050 482 2050 482 +b c p 483 50 2050 483 2050 483 +b c p 484 50 2050 484 2050 484 +b c p 485 50 2050 485 2050 485 +b c p 484 1127 2050 484 2050 484 +b c p 485 1127 2050 485 2050 485 +b c p 480 1138 2050 480 2050 480 +b c p 481 1138 2050 481 2050 481 +b c p 482 1138 2050 482 2050 482 +b c p 483 1138 2050 483 2050 483 +b c p 484 1138 2050 484 2050 484 +b c p 485 1138 2050 485 2050 485 +b c p 1 1 3 1 3 1 +b c p 1 9 3 1 3 1 +b c p 1 2048 3 1 3 1 +b c p 1 2048 5192 1 5192 1 +b c p 9 1 3 9 3 9 +b c p 576 1 3500 576 3500 576 +b c p 1 1 1 1 1 1 +b c p 102 1088 1024 102 1024 102 +b c p 102 2048 1024 102 1024 102 +b c p 485 656 1024 485 1024 485 +b c p 483 656 1024 483 1024 483 +b c p 81 128 3 81 3 81 +b c p 1022 512 515 1022 515 1022 +b c p 74 512 515 74 515 74 +b c p 253 2048 515 253 515 253 +b c p 8192 1040 515 8192 515 8192 +b c p 10 1029 515 10 515 10 +b c p 24 1040 2050 24 2050 24 +b c p 1024 1029 2050 1024 2050 1024 +b c p 480 660 2050 480 2050 480 +b c p 481 660 2050 481 2050 481 +b c p 482 660 2050 482 2050 482 +b c p 483 660 2050 483 2050 483 +b c p 484 660 2050 484 2050 484 +b c p 485 660 2050 485 2050 485 +b c p 480 679 2050 480 2050 480 +b c p 481 679 2050 481 2050 481 +b c p 482 679 2050 482 2050 482 +b c p 483 679 2050 483 2050 483 +b c p 484 679 2050 
484 2050 484 +b c p 485 679 2050 485 2050 485 +b c p 480 690 2050 480 2050 480 +b c p 481 690 2050 481 2050 481 +b c p 482 690 2050 482 2050 482 +b c p 483 690 2050 483 2050 483 +b c p 484 690 2050 484 2050 484 +b c p 485 690 2050 485 2050 485 +b c p 480 660 2048 480 2048 480 +b c p 481 660 2048 481 2048 481 +b c p 482 660 2048 482 2048 482 +b c p 483 660 2048 483 2048 483 +b c p 484 660 2048 484 2048 484 +b c p 485 660 2048 485 2048 485 +b c p 480 679 2048 480 2048 480 +b c p 481 679 2048 481 2048 481 +b c p 482 679 2048 482 2048 482 +b c p 483 679 2048 483 2048 483 +b c p 484 679 2048 484 2048 484 +b c p 485 679 2048 485 2048 485 +b c p 480 690 2048 480 2048 480 +b c p 481 690 2048 481 2048 481 +b c p 482 690 2048 482 2048 482 +b c p 483 690 2048 483 2048 483 +b c p 484 690 2048 484 2048 484 +b c p 485 690 2048 485 2048 485 +b c p 480 656 1024 480 1024 480 +b c p 480 128 3 480 3 480 +b c p 1024 512 515 1024 515 1024 +b c p 1024 2048 1024 1024 1024 1024 +b c p 1024 2048 515 1024 515 1024 +b c p 1024 1040 515 1024 515 1024 +b c p 5 1029 515 5 515 5 +b c p 1024 1029 515 1024 515 1024 +b c p 1024 1040 2050 1024 2050 1024 +b c p 1029 1029 2050 1029 2050 1029 +b c p 485 656 2050 485 2050 485 +b c p 480 672 2050 480 2050 480 +b c p 481 672 2050 481 2050 481 +b c p 482 672 2050 482 2050 482 +b c p 483 672 2050 483 2050 483 +b c p 484 672 2050 484 2050 484 +b c p 485 672 2050 485 2050 485 +b c p 480 688 2050 480 2050 480 +b c p 481 688 2050 481 2050 481 +b c p 1024 32 256 1024 256 1024 +b c P 1024 64 512 1024 512 1024 +b c P 64 800 320 64 320 64 +b c P 64 768 512 64 512 64 +b c P 16 256 512 16 512 16 +b c P 128 128 128 128 128 128 +b c P 256 512 256 256 256 256 +b c P 1024 1024 1024 1024 1024 1024 +b c P 480 640 1024 480 1024 480 +b c P 480 640 256 480 256 480 +b c P 8 64 32 8 32 8 +b c P 9 64 32 9 32 9 +b c P 10 128 64 10 64 10 +b c P 8 8 8 8 8 8 +b c P 12 12 12 12 12 12 +b c P 25 25 25 25 25 25 +b c P 25 25 20 25 20 25 +s r r 480 20 2050 2050 20 20 +s r r 481 20 2050 2050 20 20 +s r r 482 20 2050 2050 20 20 +s r p 483 20 2050 2050 20 20 +s r R 484 20 2050 2050 20 20 +s r R 485 20 2050 2050 20 20 +s r R 480 39 2050 2050 39 39 +s r R 481 39 2050 2050 39 39 +s r R 482 39 2050 2050 39 39 +s r R 483 39 2050 2050 39 39 +s r R 484 39 2050 2050 39 39 +s r p 485 39 2050 2050 39 39 +s r p 480 50 2050 2050 50 50 +s r p 481 50 2050 2050 50 50 +s r p 482 50 2050 2050 50 50 +s r p 483 50 2050 2050 50 50 +s r p 484 50 2050 2050 50 50 +s r p 485 50 2050 2050 50 50 +s r R 480 1108 2050 2050 1108 1108 +s r R 481 1108 2050 2050 1108 1108 +s r R 482 1108 2050 2050 1108 1108 +s r R 483 1108 2050 2050 1108 1108 +s r R 484 1108 2050 2050 1108 1108 +s r R 485 1108 2050 2050 1108 1108 +s r R 480 1127 2050 2050 1127 1127 +s r R 481 1127 2050 2050 1127 1127 +s r R 482 1127 2050 2050 1127 1127 +s r R 483 1127 2050 2050 1127 1127 +s r p 484 1127 2050 2050 1127 1127 +s r p 485 1127 2050 2050 1127 1127 +s r p 480 1138 2050 2050 1138 1138 +s r p 481 1138 2050 2050 1138 1138 +s r p 482 1138 2050 2050 1138 1138 +s r p 483 1138 2050 2050 1138 1138 +s r p 484 1138 2050 2050 1138 1138 +s r p 485 1138 2050 2050 1138 1138 +s r p 1 1 3 3 1 1 +s r p 1 9 3 3 9 9 +s r p 1 2048 3 3 2048 2048 +s r p 1 2048 5192 5192 2048 2048 +s r p 9 1 3 3 1 1 +s r p 576 1 3500 3500 1 1 +s r p 1 1 1 1 1 1 +s r p 102 1088 1024 1024 1088 1088 +s r p 102 2048 1024 1024 2048 2048 +s r p 485 656 1024 1024 656 656 +s r p 483 656 1024 1024 656 656 +s r p 81 128 3 3 128 128 +s r p 1022 512 515 515 512 512 +s r p 74 512 515 515 512 512 +s r p 253 2048 515 515 
2048 2048 +s r p 8192 1040 515 515 1040 1040 +s r p 10 1029 515 515 1029 1029 +s r p 24 1040 2050 2050 1040 1040 +s r p 1024 1029 2050 2050 1029 1029 +s r p 480 660 2050 2050 660 660 +s r p 481 660 2050 2050 660 660 +s r p 482 660 2050 2050 660 660 +s r p 483 660 2050 2050 660 660 +s r p 484 660 2050 2050 660 660 +s r p 485 660 2050 2050 660 660 +s r p 480 679 2050 2050 679 679 +s r p 481 679 2050 2050 679 679 +s r p 482 679 2050 2050 679 679 +s r p 483 679 2050 2050 679 679 +s r p 484 679 2050 2050 679 679 +s r p 485 679 2050 2050 679 679 +s r p 480 690 2050 2050 690 690 +s r p 481 690 2050 2050 690 690 +s r p 482 690 2050 2050 690 690 +s r p 483 690 2050 2050 690 690 +s r p 484 690 2050 2050 690 690 +s r p 485 690 2050 2050 690 690 +s r p 480 660 2048 2048 660 660 +s r p 481 660 2048 2048 660 660 +s r p 482 660 2048 2048 660 660 +s r p 483 660 2048 2048 660 660 +s r p 484 660 2048 2048 660 660 +s r p 485 660 2048 2048 660 660 +s r p 480 679 2048 2048 679 679 +s r p 481 679 2048 2048 679 679 +s r p 482 679 2048 2048 679 679 +s r p 483 679 2048 2048 679 679 +s r p 484 679 2048 2048 679 679 +s r p 485 679 2048 2048 679 679 +s r p 480 690 2048 2048 690 690 +s r p 481 690 2048 2048 690 690 +s r p 482 690 2048 2048 690 690 +s r p 483 690 2048 2048 690 690 +s r p 484 690 2048 2048 690 690 +s r p 485 690 2048 2048 690 690 +s r p 480 656 1024 1024 656 656 +s r p 480 128 3 3 128 128 +s r p 1024 512 515 515 512 512 +s r p 1024 2048 1024 1024 2048 2048 +s r p 1024 2048 515 515 2048 2048 +s r p 1024 1040 515 515 1040 1040 +s r p 5 1029 515 515 1029 1029 +s r p 1024 1029 515 515 1029 1029 +s r p 1024 1040 2050 2050 1040 1040 +s r p 1029 1029 2050 2050 1029 1029 +s r R 480 646 2050 2050 646 646 +s r R 481 646 2050 2050 646 646 +s r R 482 646 2050 2050 646 646 +s r R 483 646 2050 2050 646 646 +s r R 484 646 2050 2050 646 646 +s r R 485 646 2050 2050 646 646 +s r R 481 656 2050 2050 656 656 +s r R 482 656 2050 2050 656 656 +s r R 483 656 2050 2050 656 656 +s r R 484 656 2050 2050 656 656 +s r p 485 656 2050 2050 656 656 +s r p 480 672 2050 2050 672 672 +s r p 481 672 2050 2050 672 672 +s r p 482 672 2050 2050 672 672 +s r p 483 672 2050 2050 672 672 +s r p 484 672 2050 2050 672 672 +s r p 485 672 2050 2050 672 672 +s r p 480 688 2050 2050 688 688 +s r p 481 688 2050 2050 688 688 +s r r 482 688 2050 2050 688 688 +s r r 483 688 2050 2050 688 688 +s r r 484 688 2050 2050 688 688 +s r r 485 688 2050 2050 688 688 +s r r 1024 512 64 64 512 512 +s r r 16 256 512 512 256 256 +s r r 480 640 512 512 640 640 +s r r 64 768 512 512 768 768 +s r r 128 128 128 128 128 128 +s r r 1024 64 512 512 64 64 +s r r 1024 256 32 32 256 256 +s r r 1024 512 64 64 512 512 +s r r 480 640 512 512 640 640 +s r p 1024 32 256 256 32 32 +s r P 1024 64 512 512 64 64 +s r P 64 800 320 320 800 800 +s r P 64 768 512 512 768 768 +s r P 16 256 512 512 256 256 +s r P 128 128 128 128 128 128 +s r P 256 512 256 256 512 512 +s r P 1024 1024 1024 1024 1024 1024 +s r P 480 640 1024 1024 640 640 +s r P 480 640 256 256 640 640 +s r P 8 64 32 32 64 64 +s r P 9 64 32 32 64 64 +s r P 10 128 64 64 128 128 +s r P 8 8 8 8 8 8 +s r P 12 12 12 12 12 12 +s r P 25 25 25 25 25 25 +s r P 25 25 20 20 25 25 +i r p 480 20 2050 2050 20 20 +i r p 481 20 2050 2050 20 20 +i r p 482 20 2050 2050 20 20 +i r p 483 20 2050 2050 20 20 +i r R 484 20 2050 2050 20 20 +i r R 485 20 2050 2050 20 20 +i r R 480 39 2050 2050 39 39 +i r R 481 39 2050 2050 39 39 +i r R 482 39 2050 2050 39 39 +i r R 483 39 2050 2050 39 39 +i r R 484 39 2050 2050 39 39 +i r p 485 39 2050 2050 39 39 +i r 
p 480 50 2050 2050 50 50 +i r p 481 50 2050 2050 50 50 +i r p 482 50 2050 2050 50 50 +i r p 483 50 2050 2050 50 50 +i r p 484 50 2050 2050 50 50 +i r p 485 50 2050 2050 50 50 +i r R 480 1108 2050 2050 1108 1108 +i r R 481 1108 2050 2050 1108 1108 +i r R 482 1108 2050 2050 1108 1108 +i r R 483 1108 2050 2050 1108 1108 +i r R 484 1108 2050 2050 1108 1108 +i r R 485 1108 2050 2050 1108 1108 +i r R 480 1127 2050 2050 1127 1127 +i r R 481 1127 2050 2050 1127 1127 +i r R 482 1127 2050 2050 1127 1127 +i r R 483 1127 2050 2050 1127 1127 +i r p 484 1127 2050 2050 1127 1127 +i r p 485 1127 2050 2050 1127 1127 +i r p 480 1138 2050 2050 1138 1138 +i r p 481 1138 2050 2050 1138 1138 +i r p 482 1138 2050 2050 1138 1138 +i r p 483 1138 2050 2050 1138 1138 +i r p 484 1138 2050 2050 1138 1138 +i r p 485 1138 2050 2050 1138 1138 +i r p 1 1 3 3 1 1 +i r p 1 9 3 3 9 9 +i r p 1 2048 3 3 2048 2048 +i r p 1 2048 5192 5192 2048 2048 +i r p 9 1 3 3 1 1 +i r p 576 1 3500 3500 1 1 +i r p 1 1 1 1 1 1 +i r p 102 1088 1024 1024 1088 1088 +i r p 102 2048 1024 1024 2048 2048 +i r p 485 656 1024 1024 656 656 +i r p 483 656 1024 1024 656 656 +i r p 81 128 3 3 128 128 +i r p 1022 512 515 515 512 512 +i r p 74 512 515 515 512 512 +i r p 253 2048 515 515 2048 2048 +i r p 8192 1040 515 515 1040 1040 +i r p 10 1029 515 515 1029 1029 +i r p 24 1040 2050 2050 1040 1040 +i r p 1024 1029 2050 2050 1029 1029 +i r p 480 660 2050 2050 660 660 +i r p 481 660 2050 2050 660 660 +i r p 482 660 2050 2050 660 660 +i r p 483 660 2050 2050 660 660 +i r p 484 660 2050 2050 660 660 +i r p 485 660 2050 2050 660 660 +i r p 480 679 2050 2050 679 679 +i r p 481 679 2050 2050 679 679 +i r p 482 679 2050 2050 679 679 +i r p 483 679 2050 2050 679 679 +i r p 484 679 2050 2050 679 679 +i r p 485 679 2050 2050 679 679 +i r p 480 690 2050 2050 690 690 +i r p 481 690 2050 2050 690 690 +i r p 482 690 2050 2050 690 690 +i r p 483 690 2050 2050 690 690 +i r p 484 690 2050 2050 690 690 +i r p 485 690 2050 2050 690 690 +i r p 480 660 2048 2048 660 660 +i r p 481 660 2048 2048 660 660 +i r p 482 660 2048 2048 660 660 +i r p 483 660 2048 2048 660 660 +i r p 484 660 2048 2048 660 660 +i r p 485 660 2048 2048 660 660 +i r p 480 679 2048 2048 679 679 +i r p 481 679 2048 2048 679 679 +i r p 482 679 2048 2048 679 679 +i r p 483 679 2048 2048 679 679 +i r p 484 679 2048 2048 679 679 +i r p 485 679 2048 2048 679 679 +i r p 480 690 2048 2048 690 690 +i r p 481 690 2048 2048 690 690 +i r p 482 690 2048 2048 690 690 +i r p 483 690 2048 2048 690 690 +i r p 484 690 2048 2048 690 690 +i r p 485 690 2048 2048 690 690 +i r p 480 656 1024 1024 656 656 +i r p 480 128 3 3 128 128 +i r p 1024 512 515 515 512 512 +i r p 1024 2048 1024 1024 2048 2048 +i r p 1024 2048 515 515 2048 2048 +i r p 1024 1040 515 515 1040 1040 +i r p 5 1029 515 515 1029 1029 +i r p 1024 1029 515 515 1029 1029 +i r p 1024 1040 2050 2050 1040 1040 +i r p 1029 1029 2050 2050 1029 1029 +i r R 480 646 2050 2050 646 646 +i r R 481 646 2050 2050 646 646 +i r R 482 646 2050 2050 646 646 +i r R 483 646 2050 2050 646 646 +i r R 484 646 2050 2050 646 646 +i r R 485 646 2050 2050 646 646 +i r R 481 656 2050 2050 656 656 +i r R 482 656 2050 2050 656 656 +i r R 483 656 2050 2050 656 656 +i r R 484 656 2050 2050 656 656 +i r p 485 656 2050 2050 656 656 +i r p 480 672 2050 2050 672 672 +i r p 481 672 2050 2050 672 672 +i r p 482 672 2050 2050 672 672 +i r p 483 672 2050 2050 672 672 +i r p 484 672 2050 2050 672 672 +i r p 485 672 2050 2050 672 672 +i r p 480 688 2050 2050 688 688 +i r p 481 688 2050 2050 688 688 +i r r 482 
688 2050 2050 688 688 +i r r 483 688 2050 2050 688 688 +i r r 484 688 2050 2050 688 688 +i r r 485 688 2050 2050 688 688 +i r r 1024 512 64 64 512 512 +i r r 16 256 512 512 256 256 +i r r 480 640 512 512 640 640 +i r r 64 768 512 512 768 768 +i r r 128 128 128 128 128 128 +i r r 1024 64 512 512 64 64 +i r r 1024 256 32 32 256 256 +i r r 1024 512 64 64 512 512 +i r r 480 640 512 512 640 640 +i r p 1024 32 256 256 32 32 +i r P 1024 64 512 512 64 64 +i r P 64 800 320 320 800 800 +i r P 64 768 512 512 768 768 +i r P 16 256 512 512 256 256 +i r P 128 128 128 128 128 128 +i r P 256 512 256 256 512 512 +i r P 1024 1024 1024 1024 1024 1024 +i r P 480 640 1024 1024 640 640 +i r P 480 640 256 256 640 640 +i r P 8 64 32 32 64 64 +i r P 9 64 32 32 64 64 +i r P 10 128 64 64 128 128 +i r P 8 8 8 8 8 8 +i r P 12 12 12 12 12 12 +i r P 25 25 25 25 25 25 +i r P 25 25 20 20 25 25 +f r p 480 20 2050 2050 20 20 +f r p 481 20 2050 2050 20 20 +f r p 482 20 2050 2050 20 20 +f r p 483 20 2050 2050 20 20 +f r R 484 20 2050 2050 20 20 +f r R 485 20 2050 2050 20 20 +f r R 480 39 2050 2050 39 39 +f r R 481 39 2050 2050 39 39 +f r R 482 39 2050 2050 39 39 +f r R 483 39 2050 2050 39 39 +f r R 484 39 2050 2050 39 39 +f r p 485 39 2050 2050 39 39 +f r p 480 50 2050 2050 50 50 +f r p 481 50 2050 2050 50 50 +f r p 482 50 2050 2050 50 50 +f r p 483 50 2050 2050 50 50 +f r p 484 50 2050 2050 50 50 +f r p 485 50 2050 2050 50 50 +f r R 480 1108 2050 2050 1108 1108 +f r R 481 1108 2050 2050 1108 1108 +f r R 482 1108 2050 2050 1108 1108 +f r R 483 1108 2050 2050 1108 1108 +f r R 484 1108 2050 2050 1108 1108 +f r R 485 1108 2050 2050 1108 1108 +f r R 480 1127 2050 2050 1127 1127 +f r R 481 1127 2050 2050 1127 1127 +f r R 482 1127 2050 2050 1127 1127 +f r R 483 1127 2050 2050 1127 1127 +f r p 484 1127 2050 2050 1127 1127 +f r p 485 1127 2050 2050 1127 1127 +f r p 480 1138 2050 2050 1138 1138 +f r p 481 1138 2050 2050 1138 1138 +f r p 482 1138 2050 2050 1138 1138 +f r p 483 1138 2050 2050 1138 1138 +f r p 484 1138 2050 2050 1138 1138 +f r p 485 1138 2050 2050 1138 1138 +f r p 1 1 3 3 1 1 +f r p 1 9 3 3 9 9 +f r p 1 2048 3 3 2048 2048 +f r p 1 2048 5192 5192 2048 2048 +f r p 9 1 3 3 1 1 +f r p 576 1 3500 3500 1 1 +f r p 1 1 1 1 1 1 +f r p 102 1088 1024 1024 1088 1088 +f r p 102 2048 1024 1024 2048 2048 +f r p 485 656 1024 1024 656 656 +f r p 483 656 1024 1024 656 656 +f r p 81 128 3 3 128 128 +f r p 1022 512 515 515 512 512 +f r p 74 512 515 515 512 512 +f r p 253 2048 515 515 2048 2048 +f r p 8192 1040 515 515 1040 1040 +f r p 10 1029 515 515 1029 1029 +f r p 24 1040 2050 2050 1040 1040 +f r p 1024 1029 2050 2050 1029 1029 +f r p 480 660 2050 2050 660 660 +f r p 481 660 2050 2050 660 660 +f r p 482 660 2050 2050 660 660 +f r p 483 660 2050 2050 660 660 +f r p 484 660 2050 2050 660 660 +f r p 485 660 2050 2050 660 660 +f r p 480 679 2050 2050 679 679 +f r p 481 679 2050 2050 679 679 +f r p 482 679 2050 2050 679 679 +f r p 483 679 2050 2050 679 679 +f r p 484 679 2050 2050 679 679 +f r p 485 679 2050 2050 679 679 +f r p 480 690 2050 2050 690 690 +f r p 481 690 2050 2050 690 690 +f r p 482 690 2050 2050 690 690 +f r p 483 690 2050 2050 690 690 +f r p 484 690 2050 2050 690 690 +f r p 485 690 2050 2050 690 690 +f r p 480 660 2048 2048 660 660 +f r p 481 660 2048 2048 660 660 +f r p 482 660 2048 2048 660 660 +f r p 483 660 2048 2048 660 660 +f r p 484 660 2048 2048 660 660 +f r p 485 660 2048 2048 660 660 +f r p 480 679 2048 2048 679 679 +f r p 481 679 2048 2048 679 679 +f r p 482 679 2048 2048 679 679 +f r p 483 679 2048 2048 679 679 +f 
r p 484 679 2048 2048 679 679 +f r p 485 679 2048 2048 679 679 +f r p 480 690 2048 2048 690 690 +f r p 481 690 2048 2048 690 690 +f r p 482 690 2048 2048 690 690 +f r p 483 690 2048 2048 690 690 +f r p 484 690 2048 2048 690 690 +f r p 485 690 2048 2048 690 690 +f r p 480 656 1024 1024 656 656 +f r p 480 128 3 3 128 128 +f r p 1024 512 515 515 512 512 +f r p 1024 2048 1024 1024 2048 2048 +f r p 1024 2048 515 515 2048 2048 +f r p 1024 1040 515 515 1040 1040 +f r p 5 1029 515 515 1029 1029 +f r p 1024 1029 515 515 1029 1029 +f r p 1024 1040 2050 2050 1040 1040 +f r p 1029 1029 2050 2050 1029 1029 +f r R 480 646 2050 2050 646 646 +f r R 481 646 2050 2050 646 646 +f r R 482 646 2050 2050 646 646 +f r R 483 646 2050 2050 646 646 +f r R 484 646 2050 2050 646 646 +f r R 485 646 2050 2050 646 646 +f r R 481 656 2050 2050 656 656 +f r R 482 656 2050 2050 656 656 +f r R 483 656 2050 2050 656 656 +f r R 484 656 2050 2050 656 656 +f r p 485 656 2050 2050 656 656 +f r p 480 672 2050 2050 672 672 +f r p 481 672 2050 2050 672 672 +f r p 482 672 2050 2050 672 672 +f r p 483 672 2050 2050 672 672 +f r p 484 672 2050 2050 672 672 +f r p 485 672 2050 2050 672 672 +f r p 480 688 2050 2050 688 688 +f r p 481 688 2050 2050 688 688 +f r r 482 688 2050 2050 688 688 +f r r 483 688 2050 2050 688 688 +f r r 484 688 2050 2050 688 688 +f r r 485 688 2050 2050 688 688 +f r r 1024 512 64 64 512 512 +f r r 16 256 512 512 256 256 +f r r 480 640 512 512 640 640 +f r r 64 768 512 512 768 768 +f r r 128 128 128 128 128 128 +f r r 1024 64 512 512 64 64 +f r r 1024 256 32 32 256 256 +f r r 1024 512 64 64 512 512 +f r r 480 640 512 512 640 640 +f r p 1024 32 256 256 32 32 +f r P 1024 64 512 512 64 64 +f r P 64 800 320 320 800 800 +f r P 64 768 512 512 768 768 +f r P 16 256 512 512 256 256 +f r P 128 128 128 128 128 128 +f r P 256 512 256 256 512 512 +f r P 1024 1024 1024 1024 1024 1024 +f r P 480 640 1024 1024 640 640 +f r P 480 640 256 256 640 640 +f r P 8 64 32 32 64 64 +f r P 9 64 32 32 64 64 +f r P 10 128 64 64 128 128 +f r P 8 8 8 8 8 8 +f r P 12 12 12 12 12 12 +f r P 25 25 25 25 25 25 +f r P 25 25 20 20 25 25 +i r r 4096 256 5 5 256 256 +i r r 3000 256 128 128 256 256 +i r r 4096 1024 512 512 1024 1024 +i r r 144 256 5 5 256 256 +i r r 144 256 128 128 256 256 +i r r 144 1024 512 512 1024 1024 +i r r 480 688 256 256 688 688 +i r r 480 640 512 512 640 640 +i r r 480 640 1024 1024 640 640 +i r r 64 800 320 320 800 800 +i r r 64 768 512 512 768 768 +i r r 16 256 512 512 256 256 +i r r 128 128 128 128 128 128 +i r r 256 512 256 256 512 512 +i r r 1024 1024 1024 1024 1024 1024 +i r r 1024 32 256 256 32 32 +i r r 1024 64 512 512 64 64 +i r r 1024 256 32 32 256 256 +i r r 1024 512 64 64 512 512 +i r r 512 32 256 256 32 32 +i r r 512 768 512 512 768 768 +i r r 512 256 32 32 256 256 +i r r 512 512 64 64 512 512 +i r r 512 256 768 768 256 256 +i r r 768 768 1024 1024 768 768 +i r r 768 768 768 768 768 768 +i r r 2048 2048 2048 2048 2048 2048 +i r r 4096 4096 4096 4096 4096 4096 +f r r 4096 256 5 5 256 256 +f r r 3000 256 128 128 256 256 +f r r 4096 1024 512 512 1024 1024 +f r r 144 256 5 5 256 256 +f r r 144 256 128 128 256 256 +f r r 144 1024 512 512 1024 1024 +f r r 480 688 256 256 688 688 +f r r 480 640 512 512 640 640 +f r r 480 640 1024 1024 640 640 +f r r 64 800 320 320 800 800 +f r r 64 768 512 512 768 768 +f r r 16 256 512 512 256 256 +f r r 128 128 128 128 128 128 +f r r 256 512 256 256 512 512 +f r r 1024 1024 1024 1024 1024 1024 +f r r 1024 32 256 256 32 32 +f r r 1024 64 512 512 64 64 +f r r 1024 256 32 32 256 256 +f r r 
1024 512 64 64 512 512 +f r r 512 32 256 256 32 32 +f r r 512 768 512 512 768 768 +f r r 512 256 32 32 256 256 +f r r 512 512 64 64 512 512 +f r r 512 256 768 768 256 256 +f r r 768 768 1024 1024 768 768 +f r r 768 768 768 768 768 768 +f r r 2048 2048 2048 2048 2048 2048 +f r r 4096 4096 4096 4096 4096 4096 +f r r 2048 1024 1024 1024 1024 1024 +f r r 2048 4096 1024 1024 4096 4096 +f r r 2048 1024 4096 4096 1024 1024 +f r r 2048 1024 2 2 1024 1024 +f r r 128 1024 1024 1024 1024 1024 +f r r 1536 768 768 768 768 768 +f r r 1536 3072 768 768 3072 3072 +f r r 1536 768 3072 3072 768 768 +f r r 1536 768 2 2 768 768 +f r r 128 768 768 768 768 768 +f r r 1024 8 13 13 8 8 +f r r 1024 4 8 8 4 4 +f r r 1024 128 355 355 128 128 +f r r 1024 64 128 128 64 64 +f r r 1024 1 64 64 1 1 +f r r 480 1 256 256 1 1 +f r r 480 256 512 512 256 256 +f r r 480 1024 845 845 1024 1024 +f r r 480 512 1024 1024 512 512 +f r r 10 17191 128 128 17191 17191 +f r r 10 512 256 256 512 512 diff --git a/bench/bench_aocl_gemm/bench_lpgemm.c b/bench/bench_aocl_gemm/bench_lpgemm.c new file mode 100644 index 0000000000..92b7a7a1a6 --- /dev/null +++ b/bench/bench_aocl_gemm/bench_lpgemm.c @@ -0,0 +1,1228 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "blis.h" + +#define S8_MIN (-128) +#define S8_MAX (+127) + +// Mode can be one of the follwoing: +// 1. p - performance, used for benchmarks. +// 2. a - accuracy, used to test accuracy/correctness. +// Default value is p, can be modified by passing command line arg. +char bench_mode = 'p'; + +int32_t global_n_repeat = 0; + +char global_dscale_out = 'n'; + +#define _XSTR(str) #str +#define XSTR(str) _XSTR(str) + +#define GEN_FUNC_NAME(prototype,ctype) prototype ## ctype + +#define GEN_FILL_ARRAY_FUNC(ctype) \ +void fill_array_ ## ctype ( void* arr, dim_t size ) \ +{ \ + ctype* temp_arr = ( ctype* ) arr; \ + for ( dim_t i = 0; i < size; ++i ) \ + { \ + temp_arr[i] = ( ctype )( i % 10 ); \ + } \ +} \ + +GEN_FILL_ARRAY_FUNC(uint8_t) +GEN_FILL_ARRAY_FUNC(int8_t) +GEN_FILL_ARRAY_FUNC(float) +GEN_FILL_ARRAY_FUNC(int32_t) + +inline void float_to_bf16( float* float_value, bfloat16* bf16_val ) +{ + /*Set offset 2 to copy most significant 2 bytes of float + to convert float values to bf16 values*/ + memcpy( ( bf16_val ), (char *)( float_value ) + 2, sizeof ( bfloat16 ) ); +} + +inline void convert_float_arr_to_bf16( float* array, bfloat16* array_bf16, int size ) +{ + for (int i=0; i< size; i++) + { + float_to_bf16( ( array + i ), ( array_bf16 + i ) ); + } +} + +#define GEN_FILL_ARRAY_POST_OPS_FUNC(ctype) \ +void fill_array_post_ops_ ## ctype ( void* arr, dim_t size ) \ +{ \ + ctype* temp_arr = ( ctype* ) arr; \ + for ( dim_t i = 0; i < size; ++i ) \ + { \ + temp_arr[i] = ( ctype )( i % 20 ); \ + } \ +} \ + +GEN_FILL_ARRAY_POST_OPS_FUNC(int16_t) +GEN_FILL_ARRAY_POST_OPS_FUNC(int32_t) +GEN_FILL_ARRAY_POST_OPS_FUNC(float) + +#define GEN_BLIS_MAT_MUL_FUNC(A_type,B_type,C_type,ACCUM_type,BLAS_SFX) \ +void mat_mul_ ## BLAS_SFX \ + ( \ + char stor_order, \ + char op_t, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ACCUM_type alpha, \ + A_type* a, \ + dim_t lda, \ + B_type* b, \ + dim_t ldb, \ + ACCUM_type beta, \ + C_type* c, \ + dim_t ldc, \ + aocl_post_op* post_op\ + ) \ +{ \ + char storage = stor_order; \ + char transa = 'n'; \ + char transb = 'n'; \ + char reordera = 'n'; \ + char reorderb = 'n'; \ + \ + if ( ( op_t == 'p' ) || ( op_t == 'P' ) ) \ + { \ + /* No reordering of B.*/ \ + reordera = 'n'; \ + reorderb = 'n'; \ + } \ + else if ( ( op_t == 'r' ) || ( op_t == 
'R' ) ) \ + { \ + /* Reordered B.*/ \ + reordera = 'n'; \ + reorderb = 'r'; \ + } \ + \ + aocl_gemm_ ## BLAS_SFX( storage, transa, transb, m, n, k, \ + alpha, \ + a, lda, reordera, \ + b, ldb, reorderb, \ + beta, \ + c, ldc, post_op ); \ + \ + /*dim_t MR = 6; \ + dim_t NR = 16; \ + \ + __m512i selector1; \ + __m512i all_zero = _mm512_setzero_epi32(); \ + __m512i c0; \ + __m512i c1; \ + __m512i c2; \ + __m512i c3; \ + __m512i c4; \ + __m512i c5; \ + \ + for ( dim_t i = 0; i < m; i += MR ) \ + { \ + if ( ( i + MR ) > m ) \ + { \ + break; \ + } \ + for ( dim_t j = 0; j < n; j += NR ) \ + { \ + if ( ( j + NR ) > n ) \ + { \ + break; \ + } \ + selector1 = _mm512_loadu_epi32( (int32_t*)post_op->bias.bias + j ); \ + c0 = _mm512_loadu_epi32( c + ( ( i + 0 ) * ldc ) + j ); \ + c1 = _mm512_loadu_epi32( c + ( ( i + 1 ) * ldc ) + j ); \ + c2 = _mm512_loadu_epi32( c + ( ( i + 2 ) * ldc ) + j ); \ + c3 = _mm512_loadu_epi32( c + ( ( i + 3 ) * ldc ) + j ); \ + c4 = _mm512_loadu_epi32( c + ( ( i + 4 ) * ldc ) + j ); \ + c5 = _mm512_loadu_epi32( c + ( ( i + 5 ) * ldc ) + j ); \ + \ + c0 = _mm512_add_epi32( selector1, c0 ); \ + c1 = _mm512_add_epi32( selector1, c1 ); \ + c2 = _mm512_add_epi32( selector1, c2 ); \ + c3 = _mm512_add_epi32( selector1, c3 ); \ + c4 = _mm512_add_epi32( selector1, c4 ); \ + c5 = _mm512_add_epi32( selector1, c5 ); \ + \ + c0 = _mm512_max_epi32( all_zero, c0 ); \ + c1 = _mm512_max_epi32( all_zero, c1 ); \ + c2 = _mm512_max_epi32( all_zero, c2 ); \ + c3 = _mm512_max_epi32( all_zero, c3 ); \ + c4 = _mm512_max_epi32( all_zero, c4 ); \ + c5 = _mm512_max_epi32( all_zero, c5 ); \ + \ + _mm512_storeu_epi32( c + ( ( i + 0 ) * ldc ) + j, c0 ); \ + _mm512_storeu_epi32( c + ( ( i + 1 ) * ldc ) + j, c1 ); \ + _mm512_storeu_epi32( c + ( ( i + 2 ) * ldc ) + j, c2 ); \ + _mm512_storeu_epi32( c + ( ( i + 3 ) * ldc ) + j, c3 ); \ + _mm512_storeu_epi32( c + ( ( i + 4 ) * ldc ) + j, c4 ); \ + _mm512_storeu_epi32( c + ( ( i + 5 ) * ldc ) + j, c5 ); \ + } \ + } */\ +} \ + +GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) +GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) +GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) +GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) +GEN_BLIS_MAT_MUL_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32) +GEN_BLIS_MAT_MUL_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) +GEN_BLIS_MAT_MUL_FUNC(float,float,float,float,f32f32f32of32) + +double get_gflops + ( + dim_t m, + dim_t n, + dim_t k, + double runtime + ) +{ + return ( ( 2.0 * m * n * k ) / ( runtime * 1.0e9 ) ); +} + +void print_result + ( + const char* msg, + int32_t n_repeats, + dim_t m, + dim_t n, + dim_t k, + dim_t lda, + dim_t ldb, + dim_t ldc, + double runtime + ) +{ + double gflops = get_gflops( m, n, k, runtime ); + printf("%s m: %ld, n: %ld, k: %ld, lda: %ld, ldb: %ld, ldc: %ld," \ + " Gops: %f, n_repeats: %d\n", + msg, m, n, k, lda, ldb, ldc, gflops, n_repeats); +} + +#define GEN_MAT_MUL_BENCH_DRV_FUNC(A_type,B_type,C_type,ACCUM_type,BLAS_SFX) \ +void mat_mul_bench_driver_ ## BLAS_SFX \ + ( \ + char stor_order, \ + char op_t, \ + int32_t n_repeats, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ACCUM_type alpha, \ + A_type* a, \ + dim_t lda, \ + B_type* b, \ + dim_t ldb, \ + ACCUM_type beta, \ + C_type* c, \ + dim_t ldc, \ + aocl_post_op* post_op\ + ) \ +{ \ + double min_time_diff = DBL_MAX; \ + for ( int32_t nr = 0; nr < n_repeats; ++nr ) \ + { \ + if ( bench_mode == 'a' ) \ + { \ + memset( ( void* ) c, 0, sizeof( C_type ) * m * n ); \ + } \ 
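The benchmark's bfloat16 path relies on the truncation helpers defined above (float_to_bf16 copies the two most significant bytes of an IEEE-754 float; bf16_to_float, further down in this file, shifts them back into the high half of a 32-bit word). A minimal standalone sketch of that round trip follows for illustration only; it is not part of the patch, it assumes a little-endian host, and its bfloat16_t typedef is a stand-in for the bfloat16 type that blis.h provides.

/* Truncation-based float <-> bfloat16 round trip (illustration only). */
#include <stdint.h>
#include <stdio.h>
#include <string.h>

typedef int16_t bfloat16_t;   /* stand-in for the blis.h bfloat16 type */

static bfloat16_t f32_to_bf16( float f )
{
    bfloat16_t b;
    /* On a little-endian host, bytes 2..3 hold the sign, exponent and
       top 7 mantissa bits -- exactly what float_to_bf16 above keeps. */
    memcpy( &b, ( char* )&f + 2, sizeof( b ) );
    return b;
}

static float bf16_to_f32( bfloat16_t b )
{
    /* Re-widen by placing the 16 stored bits in the high half and
       zero-filling the discarded mantissa bits, as bf16_to_float does. */
    uint32_t i = ( uint32_t )( uint16_t )b << 16;
    float f;
    memcpy( &f, &i, sizeof( f ) );
    return f;
}

int main( void )
{
    float x = 3.14159f;
    printf( "%f -> %f\n", x, bf16_to_f32( f32_to_bf16( x ) ) ); /* ~3.140625 */
    return 0;
}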
+ \ + struct timespec tstart={0,0}, tend={0,0}; \ + clock_gettime(CLOCK_MONOTONIC, &tstart); \ + \ + GEN_FUNC_NAME(mat_mul_,BLAS_SFX) \ + ( \ + stor_order, op_t, m, n, k, \ + alpha, \ + a, lda, \ + b, ldb, \ + beta, \ + c, ldc, \ + post_op \ + ); \ + \ + clock_gettime(CLOCK_MONOTONIC, &tend); \ + \ + double diff = \ + ( ( double ) tend.tv_sec + ( 1.0e-9 * tend.tv_nsec ) ) - \ + ( ( double ) tstart.tv_sec + ( 1.0e-9 * tstart.tv_nsec ) ); \ + min_time_diff = ( diff < min_time_diff ) ? diff : min_time_diff; \ + } \ + \ + print_result( XSTR(BLAS_SFX), n_repeats, m, n, k, lda, ldb, ldc, min_time_diff); \ +} \ + +GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) +GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) +GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) +GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) +GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32) +GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) +GEN_MAT_MUL_BENCH_DRV_FUNC(float,float,float,float,f32f32f32of32) + +int max (int a, int b) +{ + return ( a > b ? a : b ); +} + +int min (int a, int b) +{ + return ( a < b ? a : b ); +} + +#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \ +inline C_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \ + (\ + ACCUM_type temp_accum,\ + C_type out_temp_accum, \ + aocl_post_op* post_op, \ + dim_t j \ + )\ +{\ + out_temp_accum = ( C_type ) min ( max ( nearbyintf( ( SCALE_type )temp_accum * \ + ( *( ( SCALE_type* )post_op->sum.scale_factor + j ) ) ), S8_MIN ), S8_MAX ) ; \ + return out_temp_accum; \ +}\ + +GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int16_t,float,u8s8s16os8) +GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,u8s8s32os8) + +inline bfloat16 mat_mul_accuracy_check_downscale_bf16bf16f32obf16 + ( + float temp_accum, + bfloat16 out_temp_accum, + aocl_post_op* post_op, + dim_t j + ) +{ + float_to_bf16( ( &temp_accum ), ( &out_temp_accum ) ); + return out_temp_accum; +} + +#define GEN_MAT_MUL_ACC_CHK_ACCUM(A_type, B_type, C_type,ACCUM_type,BLAS_SFX) \ +inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \ + (\ + A_type* a, \ + B_type* b, \ + C_type* c_ref, \ + ACCUM_type temp_accum,\ + ACCUM_type alpha, \ + ACCUM_type beta, \ + dim_t rs_a, \ + dim_t rs_b, \ + dim_t cs_a, \ + dim_t cs_b, \ + dim_t rs_c_ref, \ + dim_t cs_c_ref, \ + dim_t i, \ + dim_t j, \ + dim_t k \ + )\ +{\ + for ( dim_t p = 0; p < k; ++p) \ + { \ + temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) * \ + *( b + ( rs_b * p ) + ( cs_b * j ) ) ); \ + } \ +\ + temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) ) \ + + ( alpha * temp_accum ); \ + return temp_accum; \ +}\ + +GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) +GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) +GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) +GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) +GEN_MAT_MUL_ACC_CHK_ACCUM(float,float,float,float,f32f32f32of32) + +inline float bf16_to_float + ( + bfloat16 bf16_val + ) +{ + int32_t inter_temp = *( ( int16_t* ) &bf16_val ); + inter_temp = inter_temp << 16; + float float_value = *( float* ) ( &inter_temp ); + return float_value; +} + +inline float mat_mul_accuracy_check_accum_bf16bf16f32of32 + ( + bfloat16* a, + bfloat16* b, + float* c_ref, + float temp_accum, + float alpha, + float beta, + dim_t rs_a, + dim_t rs_b, + 
dim_t cs_a, + dim_t cs_b, + dim_t rs_c_ref, + dim_t cs_c_ref, + dim_t i, + dim_t j, + dim_t k + ) +{ + for ( dim_t p = 0; p < k; ++p) + { + float a_float = bf16_to_float( *( a + i * rs_a + p * cs_a ) ); + float b_float = bf16_to_float( *( b + p * rs_b + j * cs_b ) ); + temp_accum += ( ( a_float ) * ( b_float ) ); + } + temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) ) + + ( alpha * temp_accum ); + return temp_accum; +} + +inline float mat_mul_accuracy_check_accum_bf16bf16f32obf16 + ( + bfloat16* a, + bfloat16* b, + bfloat16* c_ref, + float temp_accum, + float alpha, + float beta, + dim_t rs_a, + dim_t rs_b, + dim_t cs_a, + dim_t cs_b, + dim_t rs_c_ref, + dim_t cs_c_ref, + dim_t i, + dim_t j, + dim_t k + ) +{ + for ( dim_t p = 0; p < k; ++p) + { + float a_float = bf16_to_float( *( a + i*rs_a + p*cs_a ) ); + float b_float = bf16_to_float( *( b + p*rs_b + j*cs_b ) ); + temp_accum += ( ( a_float ) * ( b_float ) ); + } + float c_ref_float = bf16_to_float( *( c_ref + i*rs_c_ref + j*cs_c_ref ) ); + temp_accum = ( beta * ( c_ref_float ) ) + ( alpha * temp_accum ); + + return temp_accum; +} + +#define GEN_MAT_MUL_ACC_CHK_DRV_FUNC(A_type,B_type,C_type,ACCUM_type,SCALE_type,BLAS_SFX,BLAS_DOWNSCALE_SFX) \ +void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ + ( \ + FILE* fout, \ + const char stor_order, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ACCUM_type alpha, \ + A_type* a, \ + dim_t lda, \ + B_type* b, \ + dim_t ldb, \ + ACCUM_type beta, \ + C_type* c, \ + dim_t ldc, \ + C_type* c_ref, \ + dim_t ldc_ref, \ + aocl_post_op* post_op\ + ) \ +{ \ + dim_t rs_a = lda; \ + dim_t cs_a = 1; \ + dim_t rs_b = ldb; \ + dim_t cs_b = 1; \ + dim_t rs_c = ldc; \ + dim_t cs_c = 1; \ + dim_t rs_c_ref = ldc_ref; \ + dim_t cs_c_ref = 1; \ + \ + if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \ + { \ + rs_a = 1; \ + cs_a = lda; \ + rs_b = 1; \ + cs_b = ldb; \ + rs_c = 1; \ + cs_c = ldc; \ + rs_c_ref = 1; \ + cs_c_ref = ldc_ref; \ + } \ + \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ACCUM_type temp_accum = 0; \ + C_type out_temp_accum = 0; \ + \ + temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_accum_,BLAS_SFX) \ + (a,b,c_ref,temp_accum,alpha,beta,rs_a,rs_b,cs_a,cs_b,rs_c_ref,cs_c_ref,i,j,k); \ +\ + if ( post_op != NULL ) \ + { \ + /* Apply bias followed by relu. */ \ + if ( post_op->seq_vector[0] == BIAS ) \ + { \ + if ( post_op->seq_length >= 1 ) \ + { \ + temp_accum += ( *( ( ACCUM_type* )post_op->bias.bias + j ) ); \ + } \ + if ( ( post_op->seq_length > 1 ) && \ + ( post_op->seq_vector[1] == ELTWISE ) ) \ + { \ + if ( post_op->eltwise.algo.alpha != NULL ) /* PReLU*/ \ + { \ + temp_accum = ( temp_accum > 0 ) ? \ + temp_accum : \ + ( temp_accum * \ + *( ( ACCUM_type* ) post_op->eltwise.algo.alpha ) ); \ + } \ + else \ + { \ + temp_accum = ( temp_accum > 0 ) ? temp_accum : 0 ; \ + } \ + } \ + } \ + else if ( post_op->seq_vector[0] == ELTWISE ) \ + { \ + if ( post_op->seq_length >= 1 ) \ + { \ + if ( post_op->eltwise.algo.alpha != NULL ) /* PReLU*/ \ + { \ + temp_accum = ( temp_accum > 0 ) ? \ + temp_accum : \ + ( temp_accum * *( ( ACCUM_type* ) post_op->eltwise.algo.alpha ) ); \ + } \ + else \ + { \ + temp_accum = ( temp_accum > 0 ) ? 
temp_accum : 0 ; \ + } \ + } \ + if ( ( post_op->seq_length > 1 ) && ( post_op->seq_vector[1] == BIAS ) ) \ + { \ + temp_accum += ( *( ( ACCUM_type* )post_op->bias.bias + j ) ); \ + } \ + } \ + } \ + if ( global_dscale_out == 'y' ) \ + { \ + out_temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \ + (temp_accum, out_temp_accum, post_op, j); \ + } \ + else \ + { \ + out_temp_accum = ( C_type )temp_accum; \ + } \ + \ + if ( *( c + ( rs_c * i ) + ( cs_c * j ) ) != out_temp_accum ) \ + { \ + if ( fout ) \ + { \ + fprintf( fout, "%s Failure input m: %ld, n: %ld, k: %ld," \ + " lda: %ld, ldb: %ld, ldc: %ld\n", \ + XSTR(BLAS_SFX), m, n, k, lda, ldb, ldc ); \ + fflush( fout ); \ + } \ + printf("failure, m: %ld, n: %ld, k: %ld\n", i, j, k ); \ + goto cleanup_acc; \ + } \ + } \ + } \ +cleanup_acc: \ + return; \ +} \ + +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int16_t,int16_t,float,u8s8s16os16,u8s8s16os8) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int16_t,float,u8s8s16os8,u8s8s16os8) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,float,u8s8s32os32,u8s8s32os8) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,float,u8s8s32os8,u8s8s32os8) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,float,float,float,bf16bf16f32of32,bf16bf16f32obf16) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,bfloat16,float,float,bf16bf16f32obf16,bf16bf16f32obf16) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,float,float,f32f32f32of32,bf16bf16f32obf16) + +/* Only supports bias followed by RELU and vice versa for now.*/ \ +#define GEN_MAT_MUL_POST_OPS_CREATOR(C_type,DSCALE_type,BLAS_SFX) \ +aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ + ( \ + dim_t m, \ + dim_t n, \ + char* post_ops_str \ + ) \ +{ \ + aocl_post_op* post_ops = NULL; \ + post_ops = ( aocl_post_op* ) malloc( sizeof( aocl_post_op ) ); \ + \ + if ( ( post_ops == NULL ) && ( global_dscale_out == 'n' ) ) \ + { \ + return NULL; \ + } \ + \ + /* Only supporting 3 post ops at max for now.*/ \ + dim_t max_post_ops_seq_length = 3; \ + post_ops->seq_vector = ( AOCL_POST_OP_TYPE* ) \ + malloc \ + ( \ + max_post_ops_seq_length * \ + sizeof( AOCL_POST_OP_TYPE ) \ + ); \ + \ + if ( post_ops->seq_vector == NULL ) \ + { \ + free( post_ops ); \ + return NULL; \ + } \ + \ + /* Parse post ops list.*/ \ + dim_t cur_op_index = 0; \ + /* Ensure the buffers that use NULL check in deinit code is properly set to NULL.*/ \ + post_ops->eltwise.algo.alpha = NULL; \ + post_ops->bias.bias = NULL; \ + post_ops->sum.scale_factor = NULL; \ + if ( post_ops_str != NULL ) \ + { \ + char* ops_tok = strtok(post_ops_str, ", " ); \ + bool is_param_relu = FALSE; \ + while ( ops_tok ) \ + { \ + if ( strcmp( ops_tok, "bias") == 0 ) \ + { \ + post_ops->seq_vector[cur_op_index] = BIAS; \ + } \ + else if ( strcmp( ops_tok, "relu") == 0 ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + } \ + else if ( strcmp( ops_tok, "prelu") == 0 ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + is_param_relu = TRUE; \ + } \ + ops_tok = strtok( NULL, ", " ); \ + cur_op_index++; \ + } \ + \ + /* Allocate bias buffer, return early if alloc fails.*/ \ + post_ops->bias.bias = malloc( n * sizeof( C_type ) ); \ + if ( post_ops->bias.bias == NULL ) \ + { \ + free( post_ops->seq_vector ); \ + free( post_ops ); \ + return NULL; \ + } \ + GEN_FUNC_NAME(fill_array_post_ops_,C_type)( post_ops->bias.bias, n ); \ + \ + post_ops->eltwise.is_power_of_2 = FALSE; \ + post_ops->eltwise.scale_factor = NULL; \ + post_ops->eltwise.algo.alpha = 
NULL; \ + post_ops->eltwise.algo.algo_type = RELU; \ + if ( is_param_relu == TRUE ) \ + { \ + post_ops->eltwise.algo.alpha = malloc( sizeof( C_type ) ); \ + *( ( C_type* ) post_ops->eltwise.algo.alpha ) = ( C_type )6; \ + post_ops->eltwise.algo.algo_type = PRELU; \ + } \ + post_ops->eltwise.algo.beta = NULL; \ + } \ + \ + if ( global_dscale_out == 'y' ) \ + { \ + post_ops->seq_vector[cur_op_index] = SCALE; \ + cur_op_index++; \ + \ + post_ops->sum.is_power_of_2 = FALSE; \ + post_ops->sum.scale_factor = NULL; \ + post_ops->sum.buff = NULL; \ + post_ops->sum.zero_point = NULL; \ + if ( global_dscale_out == 'y' ) \ + { \ + /* Allocate scale buffer, return early if alloc fails.*/ \ + post_ops->sum.scale_factor = malloc( n * sizeof( DSCALE_type ) ); \ + if ( post_ops->sum.scale_factor == NULL ) \ + { \ + free ( post_ops->bias.bias ); \ + free( post_ops->seq_vector ); \ + free( post_ops ); \ + return NULL; \ + } \ + /* Fill scale factor.*/ \ + DSCALE_type* temp_dscale_ptr = ( DSCALE_type* )post_ops->sum.scale_factor; \ + for ( dim_t i = 0; i < n; ++i ) \ + { \ + temp_dscale_ptr[i] = ( ( DSCALE_type )1 )/ ( ( DSCALE_type )1000 ); \ + } \ + } \ + } \ + \ + post_ops->seq_length = cur_op_index; \ + \ + return post_ops; \ +} \ + +GEN_MAT_MUL_POST_OPS_CREATOR(int16_t,float,u8s8s16os16) +GEN_MAT_MUL_POST_OPS_CREATOR(int32_t,float,u8s8s32os32) +GEN_MAT_MUL_POST_OPS_CREATOR(float,float,bf16bf16f32of32) +GEN_MAT_MUL_POST_OPS_CREATOR(float,float,f32f32f32of32) + +void lpgemm_destroy_post_ops_struct( aocl_post_op* post_ops ) +{ + if ( post_ops == NULL ) + { + return; + } + + if ( post_ops->eltwise.algo.alpha != NULL ) + { + free( post_ops->eltwise.algo.alpha ); + } + if ( post_ops->sum.scale_factor != NULL ) + { + free( post_ops->sum.scale_factor ); + } + if ( post_ops->bias.bias != NULL ) + { + free( post_ops->bias.bias ); + } + if( post_ops->seq_vector != NULL ) + { + free( post_ops->seq_vector ); + } + + free( post_ops ); +} + +#define GEN_MAT_MUL_BENCH_MAIN_FUNC(A_type,B_type,C_type,BLAS_SFX,REORDER_SFX) \ +void mat_mul_bench_main_ ## BLAS_SFX \ + ( \ + FILE* fin, \ + FILE* fout, \ + char stor_order, \ + char op_t, \ + int32_t m, \ + int32_t n, \ + int32_t k, \ + int32_t stride_a, \ + int32_t stride_b, \ + int32_t stride_c, \ + char* post_ops_str \ + ) \ +{ \ + if ( ( op_t != 'p' ) && ( op_t != 'P' ) && ( op_t != 'r' ) && ( op_t != 'R' ) ) \ + { \ + printf("The op_t ( 2nd arg in input.txt) is not valid\n"); \ + return; \ + } \ + \ + int32_t n_repeats = bli_max( 30, bli_min(( 3e10 / ( ( int64_t )m * n * k )), 100 )); \ + if ( global_n_repeat > 0 ) \ + { \ + n_repeats = global_n_repeat; \ + } \ + \ + /* Get 64 byte aligned memory.*/ \ + A_type* a = ( A_type* ) bli_malloc_user( sizeof( A_type ) * m * k ); \ + \ + B_type* b = ( B_type* ) bli_malloc_user( sizeof( B_type ) * n * k ); \ + \ + C_type* c = ( C_type* ) bli_malloc_user( sizeof( C_type ) * m * n ); \ + memset( ( void* ) c, 0, sizeof( C_type ) * m * n ); \ + \ + C_type* c_ref = ( C_type* ) bli_malloc_user( sizeof( C_type ) * m * n ); \ + memset( ( void* ) c_ref, 0, sizeof( C_type ) * m * n ); \ + \ + C_type alpha; \ + C_type beta; \ + if ( bench_mode == 'p' ) \ + { \ + alpha = 1; \ + beta = 0; \ + } \ + else if ( bench_mode == 'a' ) \ + { \ + alpha = 2; \ + beta = 9; \ + } \ + \ + GEN_FUNC_NAME(fill_array_,A_type)( a, ( m * k ) ); \ + GEN_FUNC_NAME(fill_array_,B_type)( b, ( k * n ) ); \ + \ + aocl_post_op* post_op = NULL; \ + if ( ( post_ops_str != NULL ) || ( global_dscale_out == 'y' ) ) \ + { \ + post_op = 
GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,REORDER_SFX)( m, n, post_ops_str ); \ + if ( post_op == NULL ) \ + { \ + printf(" post op struct allocation failure, returning.\n"); \ + return; \ + } \ + } \ + \ + if ( ( op_t == 'p' ) || ( op_t == 'P' ) ) \ + { \ + /* No reordering of B.*/ \ + GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ + ( \ + stor_order, op_t, n_repeats, m, n, k, \ + alpha, \ + a, stride_a, \ + b, stride_b, \ + beta, \ + c, stride_c, \ + post_op \ + ); \ + } \ + else if ( ( op_t == 'r' ) || ( op_t == 'R' ) ) \ + { \ + /* Reorder B.*/ \ + siz_t b_reorder_buf_siz_req = \ + GEN_FUNC_NAME(aocl_get_reorder_buf_size_,REORDER_SFX)( 'B', k, n ); \ + \ + B_type* b_reorder = ( B_type* ) bli_malloc_user( b_reorder_buf_siz_req ); \ + GEN_FUNC_NAME(aocl_reorder_,REORDER_SFX)( 'B', b, b_reorder, k, n, stride_b ); \ + \ + GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ + ( \ + stor_order, op_t, n_repeats, m, n, k, \ + alpha, \ + a, stride_a, \ + b_reorder, stride_b, \ + beta, \ + c, stride_c, \ + post_op \ + ); \ + \ + bli_free_user( b_reorder ); \ + } \ + \ + if ( bench_mode == 'a' ) \ + { \ + printf("Running accuracy check.\n"); \ + GEN_FUNC_NAME(mat_mul_accuracy_check_driver_,BLAS_SFX) \ + ( \ + fout, stor_order, m, n, k, \ + alpha, \ + a, stride_a, \ + b, stride_b, \ + beta, \ + c, stride_c, \ + c_ref, stride_c, \ + post_op \ + ); \ + } \ + \ + lpgemm_destroy_post_ops_struct( post_op ); \ + \ + if ( a != NULL ) \ + { \ + bli_free_user( a ); \ + } \ + if ( b != NULL ) \ + { \ + bli_free_user( b ); \ + } \ + if ( c != NULL ) \ + { \ + bli_free_user( c ); \ + } \ + if ( c_ref != NULL ) \ + { \ + bli_free_user( c_ref ); \ + } \ +} \ + +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int16_t,u8s8s16os16,u8s8s16os16) +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,u8s8s16os8,u8s8s16os16) +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32,u8s8s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,u8s8s32os8,u8s8s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(float,float,float,f32f32f32of32,f32f32f32of32) + +#define GEN_MAT_MUL_BENCH_MAIN_FUNC_BF16(C_type, BLAS_SFX) \ +void mat_mul_bench_main_ ## BLAS_SFX \ + ( \ + FILE* fin, \ + FILE* fout, \ + char stor_order, \ + char op_t, \ + int32_t m, \ + int32_t n, \ + int32_t k, \ + int32_t stride_a, \ + int32_t stride_b, \ + int32_t stride_c, \ + char* post_ops_str \ + ) \ +{ \ + if ( ( op_t != 'p' ) && ( op_t != 'P' ) && ( op_t != 'r' ) && ( op_t != 'R' ) ) \ + { \ + printf("The op_t ( 2nd arg in input.txt) is not valid\n");\ + return; \ + } \ + \ + int32_t n_repeats = bli_max( 30, bli_min(( 3e10 / ( ( int64_t )m * n * k )), 1000 )); \ + if ( global_n_repeat > 0 ) \ + { \ + n_repeats = global_n_repeat; \ + } \ + \ + /* Get 64 byte aligned memory.*/ \ + bfloat16* a = ( bfloat16* ) bli_malloc_user( sizeof( bfloat16 ) * m * k ); \ + float *a_float = bli_malloc_user( m * k * sizeof( float )); \ + for ( int32_t i = 0; i < m*k; ++i ) \ + { \ + a_float[i] = ( float ) ( i % 5 ); \ + } \ + convert_float_arr_to_bf16( a_float, a, m * k ); \ + \ + bfloat16* b = ( bfloat16* ) bli_malloc_user( sizeof( bfloat16 ) * n * k ); \ + float *b_float = bli_malloc_user( k * n * sizeof( float )); \ + for ( int32_t i = 0; i < k*n; ++i ) \ + { \ + b_float[i] = ( float ) ( i % 5 );\ + } \ + convert_float_arr_to_bf16( b_float, b, k * n ); \ + \ + C_type* c = ( C_type* ) bli_malloc_user( sizeof( C_type ) * m * n ); \ + memset( ( void* ) c, 0, sizeof( C_type ) * m * n ); \ + \ + C_type* c_ref = ( C_type* ) bli_malloc_user( sizeof( C_type ) * m * n ); \ + memset( ( void* 
) c_ref, 0, sizeof( C_type ) * m * n ); \ + \ + float alpha; \ + float beta; \ + if ( bench_mode == 'p' ) \ + { \ + alpha = 1; \ + beta = 0; \ + } \ + else if ( bench_mode == 'a' ) \ + { \ + alpha = 2; \ + beta = 9; \ + } \ + \ + aocl_post_op* post_op = NULL; \ + if ( ( post_ops_str != NULL ) || ( global_dscale_out == 'y' ) ) \ + { \ + post_op = lpgemm_create_post_ops_struct_bf16bf16f32of32( m, n, post_ops_str ); \ + if ( post_op == NULL ) \ + { \ + printf(" post op struct allocation failure, returning.\n"); \ + return; \ + } \ + } \ + \ + if ( ( op_t == 'p' ) || ( op_t == 'P' ) ) \ + { \ + /* No reordering of B.*/ \ + GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ + ( \ + stor_order, op_t, n_repeats, m, n, k, \ + alpha, \ + a, stride_a, \ + b, stride_b, \ + beta, \ + c, stride_c, \ + post_op \ + ); \ + } \ + else if ( ( op_t == 'r' ) || ( op_t == 'R' ) ) \ + { \ + /* Reorder B.*/ \ + siz_t b_reorder_buf_siz_req = \ + aocl_get_reorder_buf_size_bf16bf16f32of32( 'B', k, n ); \ + \ + bfloat16* b_reorder = ( bfloat16* ) bli_malloc_user( b_reorder_buf_siz_req ); \ + aocl_reorder_bf16bf16f32of32( 'B', b, b_reorder, k, n, stride_b ); \ + \ + GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ + ( \ + stor_order, op_t, n_repeats, m, n, k, \ + alpha, \ + a, stride_a, \ + b_reorder, stride_b, \ + beta, \ + c, stride_c, \ + post_op \ + ); \ + } \ + \ +if ( bench_mode == 'a' ) \ + { \ + printf(" Running accuracy check.\n"); \ + GEN_FUNC_NAME(mat_mul_accuracy_check_driver_,BLAS_SFX) \ + ( \ + fout, stor_order, m, n, k, \ + alpha, \ + a, stride_a, \ + b, stride_b, \ + beta, \ + c, stride_c, \ + c_ref, stride_c, \ + post_op \ + ); \ + } \ + \ + lpgemm_destroy_post_ops_struct( post_op ); \ + \ + if ( a != NULL ) \ + { \ + bli_free_user( a ); \ + } \ + if ( b != NULL ) \ + { \ + bli_free_user( b ); \ + } \ + if ( a_float != NULL ) \ + { \ + bli_free_user( a_float ); \ + } \ + if ( b_float != NULL ) \ + { \ + bli_free_user( b_float ); \ + } \ + if ( c != NULL ) \ + { \ + bli_free_user( c ); \ + } \ + if ( c_ref != NULL ) \ + { \ + bli_free_user( c_ref ); \ + } \ +} \ + +GEN_MAT_MUL_BENCH_MAIN_FUNC_BF16(float,bf16bf16f32of32) +GEN_MAT_MUL_BENCH_MAIN_FUNC_BF16(bfloat16,bf16bf16f32obf16) + +int main( int argc, char** argv ) +{ + FILE* fin = NULL; + if ( argc < 5 ) + { + printf( "Usage: ./mat_mul -i input.txt -m mode < -n 1000 -o op1,op2.. >" \ + "\nMode is either a or p. a is used for accuracy test, " \ + "whereas p is used for performance benchmarking." \ + "\nn_repeats can be set optionally using -n arg." \ + "\nPost ops can be executed optionaly by providing a " \ + "coma separated list of ops after -o arg.\nCurrently " \ + "bias and relu/prelu is supported and can be specified " \ + "as a single post op or combination of the same. eg: -o bias,relu ; -o prelu." \ + "\nDownscaled version of an API can be enabled by using -d arg. " \ + "downscale is used to enable- u8s8s32os8, u8s8s16os8 or bf16bf16f32obf16 \n" ); + exit( 1 ); + } + + char* file_name = NULL; + char* post_ops_str = NULL; + char* post_ops_str_dest = NULL; //Strtok is used to parse, need to maintain a copy. + + // Parse CLI arguments. + opterr = 0; + int opt_val; + while ( ( opt_val = getopt( argc, argv, "i:m:n:o:d" ) ) != -1 ) + { + switch ( opt_val ) + { + case 'i': + file_name = optarg; + break; + case 'm': + bench_mode = ( ( ( *optarg ) == 'a' ) || ( ( *optarg ) == 'p' ) ) ? ( *optarg ) : 'p'; + break; + case 'n': + global_n_repeat = ( atoi( optarg ) > 0 ) ? 
atoi( optarg ) : 0; + break; + case 'o': + post_ops_str = optarg; + break; + case 'd': + global_dscale_out = 'y'; + break; + default: + break; + } + } + + if ( post_ops_str != NULL ) + { + post_ops_str_dest = strdup( post_ops_str ); + } + + if ( bench_mode == 'p' ) + { + printf( "Running bench in performance benchmarking mode.\n" ); + } + else if ( bench_mode == 'a' ) + { + printf( "Running bench in accuracy/correctness testing mode.\n" ); + } + + if ( file_name == NULL ) + { + printf( " File name provided is invalid.\n" ); + exit( 1 ); + } + + fin = fopen( file_name, "r" ); + if (fin == NULL) + { + printf( "Error opening the file %s\n", argv[1] ); + exit( 1 ); + } + + FILE* fout = NULL; + + fout = fopen( "lpgemm_accuracy_test_failures.txt", "w" ); + + char op_type_char; + char op_t; + char stor_order; + int32_t m, n, k; + int32_t stride_a, stride_b, stride_c; + + const dim_t len_list_omp_cores_for_testing = 2; + const dim_t list_omp_cores_for_testing[2] = { 80, 1 }; + + dim_t core_index = 0; + bool can_run = TRUE; + while ( ( can_run == TRUE ) && ( fseek( fin, 0L, SEEK_SET ) == 0 ) ) + { + if ( bench_mode == 'p' ) + { + can_run = FALSE; + } + else if ( bench_mode == 'a' ) + { + // For accuracy testing, we test accuracy using multiple different + // number of cores. This helps uncover any bugs related to over + // subscription or varying thread factorizations. + // Set current number of cores. +#ifdef BLIS_ENABLE_OPENMP + omp_set_num_threads( list_omp_cores_for_testing[core_index] ); +#endif + printf( "Accuracy test using %ld threads.\n", + list_omp_cores_for_testing[core_index] ); + + core_index++; + if ( core_index < len_list_omp_cores_for_testing ) + { + can_run = TRUE; + } + else + { + can_run = FALSE; + } + } + + // Input format: data_type stor_type pack/reorder m n k lda ldb ldc + while ( fscanf( fin, "%c %c %c %d %d %d %d %d %d\n", + &op_type_char, &stor_order, &op_t, &m, &n, &k, + &stride_a, &stride_b, &stride_c ) == 9 ) + { + stor_order = ( ( stor_order == 'r' ) || ( stor_order == 'R' ) || + ( stor_order == 'c' ) || ( stor_order == 'C' ) ) ? 
+ stor_order : 'r'; + + if ( ( op_type_char == 'i' ) || ( op_type_char == 'I' ) ) + { + if ( global_dscale_out == 'n' ) + { + GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os32) + ( + fin, fout, stor_order, op_t, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + else + { + GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os8) + ( + fin, fout, stor_order, op_t, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + } + else if ( ( op_type_char == 'f' ) || ( op_type_char == 'F' ) ) + { + GEN_FUNC_NAME(mat_mul_bench_main_,f32f32f32of32) + ( + fin, fout, stor_order, op_t, + m, n, k, stride_a, stride_b, stride_c, + NULL + ); + } + else if ((op_type_char == 's') || (op_type_char == 'S')) + { + if ( global_dscale_out == 'n' ) + { + GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16os16) + ( + fin, fout, stor_order, op_t, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + else + { + GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16os8) + ( + fin, fout, stor_order, op_t, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + } + if ((op_type_char == 'b') || (op_type_char == 'B')) + { + if ( global_dscale_out == 'n' ) + { + GEN_FUNC_NAME(mat_mul_bench_main_, bf16bf16f32of32) + ( + fin, fout, stor_order, op_t, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + else + { + GEN_FUNC_NAME(mat_mul_bench_main_, bf16bf16f32obf16) + ( + fin, fout, stor_order, op_t, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + } + if ( post_ops_str != NULL ) + { + strcpy( post_ops_str_dest, post_ops_str ); + } + } + } + + if ( post_ops_str_dest != NULL ) + { + free( post_ops_str_dest ); + } + if ( fin ) + { + fclose( fin ); + } + if ( fout ) + { + fclose( fout ); + } + return 0; +} diff --git a/bench/bench_aocl_gemm/data_gen_lpgemm.py b/bench/bench_aocl_gemm/data_gen_lpgemm.py new file mode 100644 index 0000000000..3bc3a24421 --- /dev/null +++ b/bench/bench_aocl_gemm/data_gen_lpgemm.py @@ -0,0 +1,78 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +# Initializing global mnk_array.This array will be used to store all mnk values +mnk_array = [] + +max_elem = 2500; +out_file_name = "accuracy_test_data_lpgemm.txt" +# Important mnk generator function.This will generate all possible combinations +# of m,n,k values using formula m(t+1)=ROUND(m(t)*Base,0)+offset +def mnk_generator(): + k_1 = 1 + incr_k = 20 + while (k_1 <= max_elem): + n_1 = 1 + incr_n = 20 + while (n_1 <= max_elem): + m_1 = 1 + incr_m = 20 + while (m_1 <= max_elem): + mnk_array.append([m_1, n_1, k_1]) + if (m_1 == 1): + m_1 = m_1 + 9 + else: + m_1 = m_1 + incr_m + if (n_1 == 1): + n_1 = n_1 + 9 + else: + n_1 = n_1 + incr_n + if (k_1 == 1): + k_1 = k_1 + 9 + else: + k_1 = k_1 + incr_k + +def data_gen(): + mnk_generator() + + fout = open(out_file_name, "w") + + for ele in mnk_array: + fout.write("i r " + str(ele[0]) + " " + str(ele[1]) + " " + str(ele[2]) + " " +\ + str(ele[2]) + " " + str(ele[1]) + " " + str(ele[1]) + "\n") + + fout.truncate(fout.tell() - 1) + fout.close() + +##__main__ +data_gen() diff --git a/bench/bench_axpbyv.c b/bench/bench_axpbyv.c index 36a203f696..c962079dd6 100644 --- a/bench/bench_axpbyv.c +++ b/bench/bench_axpbyv.c @@ -97,7 +97,7 @@ int main( int argc, char** argv ) // {function name} {S, D, C, Z} {n} // {alpha_r} {alpha_i} {incx} {beta_r} {beta_i} {incy} - while ( fscanf( fin, "%s %c %ld %lf %lf %ld %lf %lf %ld\n", + while ( fscanf( fin, "%s %c " INT_FS " %lf %lf " INT_FS " %lf %lf " INT_FS "\n", tmp, &dt_ch, &n, &alpha_r, &alpha_i, &incx, &beta_r, &beta_i, &incy ) == 9 ) { diff --git a/bench/bench_copyv.c b/bench/bench_copyv.c index c46ffc6093..7be38907ed 100644 --- a/bench/bench_copyv.c +++ b/bench/bench_copyv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -101,7 +101,7 @@ int main( int argc, char** argv ) inc_t incx, incy; // {S,D,C,Z} {n incx incy} - while (fscanf(fin, "%s %c %ld %ld %ld\n", + while (fscanf(fin, "%s %c " INT_FS INT_FS INT_FS "\n", tmp, &dt_ch, &n, &incx, &incy) == 5) { diff --git a/bench/bench_dotv.c b/bench/bench_dotv.c index 80dcf8e99d..0d39594f72 100644 --- a/bench/bench_dotv.c +++ b/bench/bench_dotv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -104,8 +104,7 @@ int main( int argc, char** argv ) // {S,D,C,Z} {n incx incy} - - while (fscanf(fin, "%s %c %ld %ld %ld\n", + while (fscanf(fin, "%s %c " INT_FS INT_FS INT_FS "\n", tmp, &dt_ch, &n, &incx, &incy) == 5) { diff --git a/bench/bench_gemm.c b/bench/bench_gemm.c index 8258b61d18..908ce0fca5 100755 --- a/bench/bench_gemm.c +++ b/bench/bench_gemm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -129,8 +129,7 @@ int main( int argc, char** argv ) // beta_real, beta_imag, ldc, // // number of threads, execution time, gflops ---> ignored by bench - - while (fscanf(fin, "%s %c %c %c %ld %ld %ld %lf %lf %ld %ld %lf %lf %ld[^\n]", + while (fscanf(fin, "%s %c %c %c " INT_FS INT_FS INT_FS " %lf %lf " INT_FS INT_FS " %lf %lf " INT_FS"[^\n]", api_name, &dt_ch, &transA_c, &transB_c, &m, &n, &k, &alpha_r, &alpha_i, &lda, &ldb, &beta_r, &beta_i, &ldc) == 14) { diff --git a/bench/bench_gemmt.c b/bench/bench_gemmt.c index 621c9288c7..ad24593747 100644 --- a/bench/bench_gemmt.c +++ b/bench/bench_gemmt.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. modification, are permitted provided that the following conditions are met: @@ -122,7 +122,7 @@ int main( int argc, char** argv ) stor_scheme = 'C'; // since logs are collected at BLAS APIs // {S,D,C,Z} {triangC : l or u} {n k lda ldb ldc transa transb alpha_real alpha_imaginary beta_real, beta_imaginary} - while (fscanf(fin,"%s %c %c %ld %ld %lu %lu %lu %c %c %lf %lf %lf %lf\n",\ + while (fscanf(fin,"%s %c %c " INT_FS INT_FS UINT_FS UINT_FS UINT_FS " %c %c %lf %lf %lf %lf\n",\ tmp, &dt_ch, &uplo_c, &n, &k,\ &lda, &ldb, &ldc, &transA_c, &transB_c, \ &alpha_r, &alpha_i, &beta_r, &beta_i) == 14) diff --git a/bench/bench_gemv.c b/bench/bench_gemv.c index acc4598000..9f06bf8efb 100755 --- a/bench/bench_gemv.c +++ b/bench/bench_gemv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -112,8 +112,7 @@ int main( int argc, char** argv ) // {S,D,C,Z} {transa m n alpha lda, incx, beta, incy} - - while (fscanf(fin, "%s %c %c %ld %ld %lf %lf %ld %ld %lf %lf %ld\n", + while (fscanf(fin, "%s %c %c " INT_FS INT_FS " %lf %lf " INT_FS INT_FS " %lf %lf " INT_FS "\n", tmp, &dt_ch, &transA, &m, &n, &alpha_r, &alpha_i, &lda,\ &incx, &beta_r, &beta_i, &incy) == 12) { diff --git a/bench/bench_ger.c b/bench/bench_ger.c index fb50c94265..2c8981a682 100644 --- a/bench/bench_ger.c +++ b/bench/bench_ger.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021-22, Advanced Micro Devices, Inc. All rights reserved. 
+ Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -116,8 +116,7 @@ int main( int argc, char** argv ) #endif // {S,D,C,Z} {transa m n alpha incx incy lda} - - while (fscanf(fin, "%s %c %ld %ld %lf %lf %ld %ld %ld\n", + while (fscanf(fin, "%s %c " INT_FS INT_FS " %lf %lf " INT_FS INT_FS INT_FS "\n", tmp, &dt_ch, &m, &n, &alpha_r, &alpha_i, &incx, &incy, &lda) == 9) { diff --git a/bench/bench_nrm2.c b/bench/bench_nrm2.c new file mode 100644 index 0000000000..ae79eb3307 --- /dev/null +++ b/bench/bench_nrm2.c @@ -0,0 +1,241 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of The University of Texas nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef WIN32 +#include +#else +#include +#endif +#include "blis.h" + + +// Benchmark application to process aocl logs generated by BLIS library. +#ifndef DT +#define DT BLIS_DOUBLE +#endif + + +#define AOCL_MATRIX_INITIALISATION + +//#define BLIS_ENABLE_CBLAS + +/* For BLIS since logs are collected at BLAS interfaces + * we disable cblas interfaces for this benchmark application + */ + +/* #ifdef BLIS_ENABLE_CBLAS */ +/* #define CBLAS */ +/* #endif */ + +int main( int argc, char** argv ) +{ + obj_t x; + obj_t normf; + dim_t p_inc = 0; // to keep track of number of inputs + num_t dt; + char dt_ch; + int r, n_repeats; + + double dtime; + double dtime_save; + double gflops; + + FILE* fin = NULL; + FILE* fout = NULL; + + n_repeats = N_REPEAT; // This macro will get from Makefile. 
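The bench_* hunks above replace hard-coded "%ld"/"%lu" conversions with INT_FS/UINT_FS tokens spliced into the fscanf format string through C string-literal concatenation, so that dim_t and inc_t values are scanned with the correct width on every platform (notably 64-bit Windows, where long is 32 bits). The sketch below shows the mechanism in isolation; it is illustration only, not part of the patch, and its INT_FS definition is a hypothetical stand-in for the one the benchmark headers provide.

#include <inttypes.h>
#include <stdio.h>

typedef int64_t dim_t_demo;    /* stand-in for the BLIS dim_t/inc_t integer type */
#define INT_FS " %" SCNd64     /* hypothetical stand-in for the real bench macro */

int main( void )
{
    dim_t_demo n, incx, incy;
    /* Same shape as the bench_copyv.c input line: {S,D,C,Z} {n incx incy} */
    const char* log_line = "dcopyv:121: D 1024 1 1";
    if ( sscanf( log_line, "%*s %*c" INT_FS INT_FS INT_FS, &n, &incx, &incy ) == 3 )
    {
        printf( "n=%" PRId64 " incx=%" PRId64 " incy=%" PRId64 "\n", n, incx, incy );
    }
    return 0;
}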
+ + dt = DT; + + if ( argc < 3 ) + { + printf("Usage: ./test_nrm2_XX.x input.csv output.csv [number_repeats]\n"); + exit(1); + } + fin = fopen( argv[1], "r" ); + if ( argc == 4 ) + { + n_repeats = atoi(argv[3]); + } + if ( fin == NULL ) + { + printf("Error opening the file %s\n", argv[1]); + exit(1); + } + fout = fopen(argv[2], "w"); + if (fout == NULL) + { + printf("Error opening output file %s\n", argv[2]); + exit(1); + } + + fprintf(fout, "Dt\t n\t incx\t gflops\n"); + dim_t n; + inc_t incx; + char tmp[256]; // to store function name, line no present in logs. + + + // {S,D,C,Z} {n incx} + while (fscanf(fin, "%s %c" INT_FS INT_FS "\n", + tmp, &dt_ch, &n, &incx) == 4) + { + +#ifdef PRINT + fprintf (stdout, "Input = %s %c %ld %ld\n", + tmp, dt_ch, n, incx); +#endif + + if (dt_ch == 'D' || dt_ch == 'd') dt = BLIS_DOUBLE; + else if (dt_ch == 'Z' || dt_ch == 'z') dt = BLIS_DCOMPLEX; + else if (dt_ch == 'S' || dt_ch == 's') dt = BLIS_FLOAT; + else if (dt_ch == 'C' || dt_ch == 'c') dt = BLIS_SCOMPLEX; + else + { + printf("Invalid data type %c\n", dt_ch); + continue; + } + + // Create objects with required sizes and strides. + + // The ?nrm2 routines compute the Euclidean norm of a vector X + // norm = ||X|| + // defined as the square root of the sum of squares of the vector elements + // where: + // X is an n-element vector. + + bli_obj_create( dt, n, 1, incx, 1, &x ); + bli_obj_create_1x1( dt, &normf ); +#ifdef AOCL_MATRIX_INITIALISATION + bli_randv( &x ); +#endif + dtime_save = DBL_MAX; + + for ( r = 0; r < n_repeats; ++r ) + { + +#ifdef PRINT + bli_printm( "x", &x, "%4.1f", "" ); +#endif + dtime = bli_clock(); + +#ifdef BLIS + bli_normfv(&x, &normf); +#else // BLIS Interface + + // Set data type independent inputs for BLAS and + // CBLAS API's + + f77_int nn = bli_obj_length( &x ); + f77_int blas_incx = bli_obj_vector_inc( &x ); + + if ( bli_is_float( dt ) ){ + float* xp = bli_obj_buffer( &x ); + float* normfp = bli_obj_buffer( &normf ); +#ifdef CBLAS + *normfp = cblas_snrm2( nn, xp, blas_incx ); +#else // cblas snrm2 + *normfp = snrm2_( &nn, xp, &blas_incx); +#endif // cblas snrm2 + } + else if ( bli_is_double( dt ) ) + { + + double* xp = bli_obj_buffer( &x ); + double* normfp = bli_obj_buffer( &normf ); + +#ifdef CBLAS + *normfp = cblas_dnrm2( nn, xp, blas_incx ); + +#else // cblas dnrm2 + *normfp = dnrm2_( &nn, xp, &blas_incx); +#endif // cblas dnrm2 + } + else if ( bli_is_scomplex( dt ) ) + { + scomplex* xp = bli_obj_buffer( &x ); + float* normfp = bli_obj_buffer( &normf ); + +#ifdef CBLAS + *normfp = cblas_scnrm2( nn, xp, blas_incx ); +#else // cblas cnrm2 + *normfp = scnrm2_( &nn, xp, &blas_incx); +#endif // cblas cnrm2 + } + else if ( bli_is_dcomplex( dt ) ) + { + dcomplex* xp = bli_obj_buffer( &x ); + double* normfp = bli_obj_buffer( &normf ); +#ifdef CBLAS + *normfp = cblas_dznrm2( nn, xp, blas_incx ); +#else // cblas znrm2 + *normfp = dznrm2_( &nn, xp, &blas_incx); +#endif // cblas znrm2 + } + +#endif // BLIS Interface + +#ifdef PRINT + bli_printm( "x after", &x "%4.1f", "" ); + exit(1); +#endif + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + + gflops = (2*n) / ( dtime_save * 1.0e9 ); + + if ( bli_is_complex( dt ) ) gflops *= 2.0; + + printf( "data_nrm2_%s", BLAS ); + + p_inc++; + printf("( %2lu, 1:4 ) = [ %4lu %7.2f ];\n", + (unsigned long)(p_inc), + (unsigned long)n, + gflops); + + fprintf (fout, "%c %ld %ld %6.3f\n", + dt_ch, n, incx, gflops); + + fflush(fout); + + bli_obj_free( &x ); + } + + //bli_finalize(); + fclose(fin); + fclose(fout); + + return 0; +} diff 
--git a/bench/bench_scalv.c b/bench/bench_scalv.c index 404d5078f5..b8cd6241c1 100644 --- a/bench/bench_scalv.c +++ b/bench/bench_scalv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -105,8 +105,7 @@ int main( int argc, char** argv ) // {S,D,C,Z} {alpha n incx} - - while (fscanf(fin, "%s %c %lf %lf %ld %ld\n", + while (fscanf(fin, "%s %c %lf %lf " INT_FS INT_FS "\n", tmp, &dt_ch, &alpha_r, &alpha_i, &n, &incx) == 6) { diff --git a/bench/bench_swapv.c b/bench/bench_swapv.c index 16aafdaaed..34af6b7975 100644 --- a/bench/bench_swapv.c +++ b/bench/bench_swapv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -103,7 +103,7 @@ int main( int argc, char** argv ) char tmp[256]; // to store function name, line no present in logs. // {S,D,C,Z} {n incx incy} - while (fscanf(fin, "%s %c %ld %ld %ld\n", + while (fscanf(fin, "%s %c " INT_FS INT_FS INT_FS "\n", tmp, &dt_ch, &n, &incx, &incy) == 5) { diff --git a/bench/bench_syrk.c b/bench/bench_syrk.c index 017b010dfc..b65db83aa5 100644 --- a/bench/bench_syrk.c +++ b/bench/bench_syrk.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. modification, are permitted provided that the following conditions are met: @@ -120,7 +120,7 @@ int main( int argc, char** argv ) stor_scheme = 'C'; // since logs are collected at BLAS APIs // {S,D,C,Z}{ uploc, transa, n, k, alpha_real, alpha_imag, lda, beta_real, beta_imag, ldc} - while (fscanf(fin, "%s %c %c %c %ld %ld %lf %lf %lu %lf %lf %lu\n",\ + while (fscanf(fin, "%s %c %c %c " INT_FS INT_FS " %lf %lf " UINT_FS " %lf %lf " UINT_FS "\n",\ tmp, &dt_ch, &uplo_c, &transA_c, &n, &k, &alpha_r,\ &alpha_i, &lda, &beta_r, &beta_i, &ldc) == 12) { diff --git a/bench/bench_trsm.c b/bench/bench_trsm.c index a7d62ebecc..7014bd4753 100644 --- a/bench/bench_trsm.c +++ b/bench/bench_trsm.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. 
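The switch from hard-coded "%ld" conversions to the INT_FS macro in these readers is a portability fix: dim_t and inc_t can be 64-bit even on LLP64 platforms such as 64-bit Windows, where long is 32 bits and "%ld" no longer matches the argument. The sketch below illustrates the idea with stand-in typedefs and an assumed INT_FS definition; the real definition lives in the bench framework, not here.

#include <stdio.h>
#include <inttypes.h>

typedef int64_t dim_t;              /* stand-ins for the BLIS typedefs */
typedef int64_t inc_t;

#define INT_FS "%" SCNd64           /* assumed definition; the bench code centralizes this */

int main( void )
{
    dim_t n; inc_t incx;
    /* Same pattern as the bench_* readers: parse "n incx" from a log line. */
    if ( sscanf( "1024 2", INT_FS " " INT_FS, &n, &incx ) == 2 )
        printf( "n=%" PRId64 " incx=%" PRId64 "\n", (int64_t)n, (int64_t)incx );
    return 0;
}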
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -62,7 +62,7 @@ int main( int argc, char** argv ) dim_t p_inc = 0; // to keep track of number of inputs num_t dt = BLIS_DOUBLE; dim_t r, n_repeats; - f77_char side; + side_t side; uplo_t uploa; trans_t transa; diag_t diaga; @@ -101,7 +101,7 @@ int main( int argc, char** argv ) f77_char dt_type_arg, side_arg, uploa_arg, transa_arg, diaga_arg; f77_char logline[255]; // input order: {S,D,C,Z} {side, uplo, transa, diag, m, n, lda, ldb, alphaR, alphaI} - while(fscanf(fin, "%s %c %c %c %c %c %ld %ld %ld %ld %lf %lf\n", + while(fscanf(fin, "%s %c %c %c %c %c " INT_FS INT_FS INT_FS INT_FS " %lf %lf\n", logline, &dt_type_arg, &side_arg, &uploa_arg, &transa_arg, &diaga_arg, &m, &n, &lda, &ldb, &alphaR, &alphaI) == 12) { @@ -191,7 +191,7 @@ int main( int argc, char** argv ) #endif dtime = bli_clock(); #ifdef BLIS - bli_trsm( &side, + bli_trsm( side, &alpha, &a, &b ); diff --git a/bench/bench_trsv.c b/bench/bench_trsv.c index ca18d3fdc9..ddf3ea187a 100644 --- a/bench/bench_trsv.c +++ b/bench/bench_trsv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -121,8 +121,7 @@ int main( int argc, char** argv ) fprintf(fout, "Dt uploa\t transa\t diaga\t m\t lda\t incx\t gflops\n"); // {S,D,C,Z} {uploa transa diaga m lda, incx} - - while (fscanf(fin, "%s %c %c %c %c %ld %ld %ld\n", + while (fscanf(fin, "%s %c %c %c %c " INT_FS INT_FS INT_FS "\n", tmp, &dt_ch, &uploa_c, &transA, &diaga_c, &m, &lda, &incx) == 8) { diff --git a/bench/inputnrm2.txt b/bench/inputnrm2.txt new file mode 100644 index 0000000000..567d6e4691 --- /dev/null +++ b/bench/inputnrm2.txt @@ -0,0 +1,42 @@ +dnrm2:171: D 2 1 +dnrm2:171: D 4 1 +dnrm2:171: D 8 1 +dnrm2:171: D 42 1 +dnrm2:171: D 64 1 +dnrm2:171: D 87 1 +dnrm2:171: D 100 1 +dnrm2:171: D 128 1 +dnrm2:171: D 189 1 +dnrm2:171: D 208 1 +dnrm2:171: D 256 1 +dnrm2:171: D 313 1 +dnrm2:171: D 512 1 +dnrm2:171: D 718 1 +dnrm2:171: D 932 1 +dnrm2:171: D 1024 1 +dnrm2:171: D 1895 1 +dnrm2:171: D 2048 1 +dnrm2:171: D 3275 1 +dnrm2:171: D 4096 1 +dnrm2:171: D 6749 1 +dnrm2:171: D 8192 1 +dnrm2:171: D 10001 1 +dnrm2:171: D 16384 1 +dnrm2:171: D 20976 1 +dnrm2:171: D 32768 1 +dnrm2:171: D 56841 1 +dnrm2:171: D 65536 1 +dnrm2:171: D 8 3 +dnrm2:171: D 64 7 +dnrm2:171: D 87 12 +dnrm2:171: D 189 9 +dnrm2:171: D 313 3 +dnrm2:171: D 718 5 +dnrm2:171: D 1024 2 +dnrm2:171: D 3275 4 +dnrm2:171: D 4096 7 +dnrm2:171: D 8192 5 +dnrm2:171: D 16384 11 +dnrm2:171: D 20976 3 +dnrm2:171: D 56841 19 +dnrm2:171: D 65536 6 \ No newline at end of file diff --git a/blastest/f2c/rdfmt.c b/blastest/f2c/rdfmt.c index 6349e3f3fd..0d8a0bf12e 100644 --- a/blastest/f2c/rdfmt.c +++ b/blastest/f2c/rdfmt.c @@ -249,9 +249,13 @@ static int rd_F(ufloat *p, int w, int d, ftnlen len) } while(ch == ' ') { blankdrop: - if (!w--) goto zero; GET(ch); } - while(ch == '0') - { if (!w--) goto zero; GET(ch); } + if (!w--) goto zero; + GET(ch); + } + while(ch == '0') { + if (!w--) goto zero; + GET(ch); + } if (ch == ' ' && f__cblank) goto blankdrop; scale1 = f__scale; @@ -262,7 +266,7 @@ static int rd_F(ufloat *p, int w, int d, ftnlen len) digloop1e: if (!w--) goto done; 
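The bench_trsm.c hunk above also corrects the type of the side argument: the object-based bli_trsm() expects a side_t enumeration passed by value, not the address of an f77_char. A minimal sketch of that call pattern, assuming a built BLIS and ignoring timing and accuracy concerns:

#include "blis.h"

int main( void )
{
    obj_t a, b, alpha;
    dim_t m = 4, n = 3;

    bli_obj_create( BLIS_DOUBLE, m, m, 0, 0, &a );   /* triangular factor */
    bli_obj_create( BLIS_DOUBLE, m, n, 0, 0, &b );   /* right-hand sides  */
    bli_obj_create_1x1( BLIS_DOUBLE, &alpha );

    bli_randm( &a );
    bli_randm( &b );
    bli_setsc( 1.0, 0.0, &alpha );

    /* Mark A as lower triangular with a non-unit diagonal. */
    bli_obj_set_struc( BLIS_TRIANGULAR, &a );
    bli_obj_set_uplo( BLIS_LOWER, &a );
    bli_obj_set_diag( BLIS_NONUNIT_DIAG, &a );

    /* The side is a side_t value passed directly; no address-of operator. */
    bli_trsm( BLIS_LEFT, &alpha, &a, &b );

    bli_obj_free( &a );
    bli_obj_free( &b );
    bli_obj_free( &alpha );
    return 0;
}

A real test would also condition A's diagonal before solving; the point here is only the side_t argument.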
GET(ch); - } + } if (ch == ' ') { if (f__cblank) { ch = '0'; goto digloop1; } diff --git a/build/bli_addon.h.in b/build/bli_addon.h.in new file mode 100644 index 0000000000..36a8e29bd1 --- /dev/null +++ b/build/bli_addon.h.in @@ -0,0 +1,47 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_ADDON_H +#define BLIS_ADDON_H + +#if @enable_addons@ +#define BLIS_ENABLE_ADDONS +#else +#define BLIS_DISABLE_ADDONS +#endif + +// Enabled addons +@addon_list_includes@ + +#endif diff --git a/build/bli_config.h.in b/build/bli_config.h.in index 73f51baed2..6c17fc5e74 100644 --- a/build/bli_config.h.in +++ b/build/bli_config.h.in @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -194,4 +194,10 @@ #define BLIS_DISABLE_COMPLEX_RETURN_INTEL #endif +#if @disable_blis_arch_type@ +#define DISABLE_BLIS_ARCH_TYPE +#endif + +#define __blis_arch_type_name "@rename_blis_arch_type@" + #endif diff --git a/build/bli_win_config.h.in b/build/bli_win_config.h.in index 6c61b2b1a4..24e1fc3d59 100644 --- a/build/bli_win_config.h.in +++ b/build/bli_win_config.h.in @@ -1,5 +1,5 @@ /* - * Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All Rights Reserved + * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. 
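The new @disable_blis_arch_type@ and @rename_blis_arch_type@ substitutions in bli_config.h.in translate the corresponding configure options into a pair of compile-time switches. The sketch below shows how such generated macros would typically be consumed at runtime; the actual code-path selection lives in the BLIS runtime sources rather than in this hunk, so treat the helper as illustrative only.

#include <stdio.h>
#include <stdlib.h>

#ifndef __blis_arch_type_name
#define __blis_arch_type_name "BLIS_ARCH_TYPE"   /* assumed default */
#endif

static const char* query_arch_override( void )
{
#ifdef DISABLE_BLIS_ARCH_TYPE
    return NULL;                            /* feature compiled out */
#else
    return getenv( __blis_arch_type_name ); /* configured env-var name */
#endif
}

int main( void )
{
    const char* req = query_arch_override();
    printf( "requested code path: %s\n", req ? req : "(auto-detect)" );
    return 0;
}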
All Rights Reserved */ #ifndef BLIS_CONFIG_H @@ -47,4 +47,8 @@ #cmakedefine BLIS_ENABLE_COMPLEX_RETURN_INTEL +#cmakedefine DISABLE_BLIS_ARCH_TYPE + +#cmakedefine __blis_arch_type_name "@rename_blis_arch_type@" + #endif diff --git a/build/blis_ref_kernel_mirror.py b/build/blis_ref_kernel_mirror.py index 8ef90a12af..834de1cee9 100644 --- a/build/blis_ref_kernel_mirror.py +++ b/build/blis_ref_kernel_mirror.py @@ -76,7 +76,9 @@ def remove_lines_in_file(filename): 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen2 ' '${CMAKE_BINARY_DIR}/ref_kernels/zen2)\n' 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen3 ' - '${CMAKE_BINARY_DIR}/ref_kernels/zen3)\nelse()', '\n') + '${CMAKE_BINARY_DIR}/ref_kernels/zen3)\n' + 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen4 ' + '${CMAKE_BINARY_DIR}/ref_kernels/zen4)\nelse()', '\n') data = file_content.replace('endif()', '\n') with open(filename, 'w') as fd: fd.write(data + '\n') diff --git a/build/config.mk.in b/build/config.mk.in index a880074e8f..eddb69f705 100644 --- a/build/config.mk.in +++ b/build/config.mk.in @@ -183,6 +183,10 @@ MK_ENABLE_CBLAS := @enable_cblas@ # Whether libblis will depend on libmemkind for certain memory allocations. MK_ENABLE_MEMKIND := @enable_memkind@ +# The names of the addons to include when building BLIS. If empty, no addons +# will be included. +ADDON_LIST := @addon_list@ + # The name of a sandbox defining an alternative gemm implementation. If empty, # no sandbox will be used and the conventional gemm implementation will remain # enabled. diff --git a/common.mk b/common.mk index 00b1d4354e..220e8ccaa8 100644 --- a/common.mk +++ b/common.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -159,20 +159,46 @@ get-kernel-cflags-for = $(strip $(call load-var-for,CKOPTFLAGS,$(1)) \ $(BUILD_SYMFLAGS) \ ) +# When compiling addons, we use flags similar to those of general framework +# source. This ensures that the same code can be linked and run across various +# sub-configurations. +get-addon-c99flags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ + $(call get-noopt-cflags-for,$(1)) \ + $(CADDONINCFLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ + ) +get-addon-cxxflags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ + $(call get-noopt-cxxflags-for,$(1)) \ + $(CADDONINCFLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ + ) +# When compiling addon kernels, we use flags similar to those of kernels +# flags, except we also include the addon header paths. +get-addon-kernel-c99flags-for = $(strip $(call load-var-for,CKOPTFLAGS,$(1)) \ + $(call load-var-for,CKVECFLAGS,$(1)) \ + $(call get-noopt-cflags-for,$(1)) \ + $(CADDONINCFLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ + ) + # When compiling sandboxes, we use flags similar to those of general framework # source. This ensures that the same code can be linked and run across various -# sub-configurations. (If we switch to using refkern/kernel flags, we should -# prevent enabling sandboxes for umbrella families by verifying that -# config_list == config_name if --enable-sandbox is given.) +# sub-configurations. 
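The get-addon-*flags-for functions added to common.mk compile addon sources with framework-style flags plus the addon include paths (CADDONINCFLAGS), so an addon is simply an extra directory of C/C++ sources built into libblis. Purely as an illustration (the addon name and routine below are invented, not part of the patch), an addon source file could be as small as:

/* Hypothetical file addon/example/example_addon.c -- illustration only. */
#include "blis.h"

/* An extra routine an addon might add to the library. It would be compiled
   with the new get-addon-c99flags-for flags and linked into libblis like
   any other framework source file. */
double example_addon_sumv( dim_t n, const double* x, inc_t incx )
{
    double sum = 0.0;
    for ( dim_t i = 0; i < n; ++i ) sum += x[ i * incx ];
    return sum;
}

Such an addon would be enabled at configure time with --enable-addon=example (or -a example), which is what populates the ADDON_LIST variable shown above.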
(NOTE: If we ever switch to using refkernel or kernel +# flags, we should prevent enabling sandboxes for umbrella families by verifying +# that config_list == config_name if --enable-sandbox is given. THIS ALSO +# APPLIES TO ADDONS ABOVE.) get-sandbox-c99flags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ $(call get-noopt-cflags-for,$(1)) \ - $(CSBOXINCFLAGS) \ + $(CSANDINCFLAGS) \ $(BUILD_CPPFLAGS) \ $(BUILD_SYMFLAGS) \ ) get-sandbox-cxxflags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ $(call get-noopt-cxxflags-for,$(1)) \ - $(CSBOXINCFLAGS) \ + $(CSANDINCFLAGS) \ $(BUILD_CPPFLAGS) \ $(BUILD_SYMFLAGS) \ ) @@ -191,15 +217,18 @@ get-user-cflags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ # Define functions that return messages appropriate for each non-verbose line # of compilation output. -get-noopt-text = "(CFLAGS for no optimization)" -get-refinit-text-for = "('$(1)' CFLAGS for ref. kernel init)" -get-refkern-text-for = "('$(1)' CFLAGS for ref. kernels)" -get-config-text-for = "('$(1)' CFLAGS for config code)" -get-frame-text-for = "('$(1)' CFLAGS for framework code)" -get-aocldtl-text-for = "('$(1)' CFLAGS for AOCL debug and trace code)" -get-kernel-text-for = "('$(1)' CFLAGS for kernels)" -get-sandbox-c99text-for = "('$(1)' CFLAGS for sandboxes)" -get-sandbox-cxxtext-for = "('$(1)' CXXFLAGS for sandboxes)" +get-noopt-text = "(CFLAGS for no optimization)" +get-refinit-text-for = "('$(1)' CFLAGS for ref. kernel init)" +get-refkern-text-for = "('$(1)' CFLAGS for ref. kernels)" +get-config-text-for = "('$(1)' CFLAGS for config code)" +get-frame-text-for = "('$(1)' CFLAGS for framework code)" +get-aocldtl-text-for = "('$(1)' CFLAGS for AOCL debug and trace code)" +get-kernel-text-for = "('$(1)' CFLAGS for kernels)" +get-addon-c99text-for = "('$(1)' CFLAGS for addons)" +get-addon-cxxtext-for = "('$(1)' CXXFLAGS for addons)" +get-addon-kernel-text-for = "('$(1)' CFLAGS for addon kernels)" +get-sandbox-c99text-for = "('$(1)' CFLAGS for sandboxes)" +get-sandbox-cxxtext-for = "('$(1)' CXXFLAGS for sandboxes)" @@ -212,6 +241,10 @@ get-sandbox-cxxtext-for = "('$(1)' CXXFLAGS for sandboxes)" files-that-contain = $(strip $(foreach f, $(1), $(if $(findstring $(2),$(f)),$(f),))) files-that-dont-contain = $(strip $(foreach f, $(1), $(if $(findstring $(2),$(f)),,$(f)))) +# Define a function that removes duplicate strings *without* using the sort +# function. +rm-dups = $(if $1,$(firstword $1) $(call rm-dups,$(filter-out $(firstword $1),$1))) + # # --- Include makefile configuration file -------------------------------------- @@ -297,6 +330,7 @@ FRAME_DIR := frame AOCLDTL_DIR := aocl_dtl REFKERN_DIR := ref_kernels KERNELS_DIR := kernels +ADDON_DIR := addon SANDBOX_DIR := sandbox OBJ_DIR := obj LIB_DIR := lib @@ -313,12 +347,13 @@ REFNM := ref # Source suffixes. CONFIG_SRC_SUFS := c - KERNELS_SRC_SUFS := c s S - FRAME_SRC_SUFS := c AOCLDTL_SRC_SUFS := c +ADDON_C99_SUFS := c +ADDON_CXX_SUFS := cc cpp cxx +ADDON_SRC_SUFS := $(ADDON_C99_SUFS) $(ADDON_CXX_SUFS) SANDBOX_C99_SUFS := c SANDBOX_CXX_SUFS := cc cpp cxx @@ -328,6 +363,9 @@ SANDBOX_SRC_SUFS := $(SANDBOX_C99_SUFS) $(SANDBOX_CXX_SUFS) FRAME_HDR_SUFS := h AOCLDTL_HDR_SUFS := h +ADDON_H99_SUFS := h +ADDON_HXX_SUFS := hh hpp hxx +ADDON_HDR_SUFS := $(ADDON_H99_SUFS) $(ADDON_HXX_SUFS) SANDBOX_H99_SUFS := h SANDBOX_HXX_SUFS := hh hpp hxx @@ -335,10 +373,12 @@ SANDBOX_HDR_SUFS := $(SANDBOX_H99_SUFS) $(SANDBOX_HXX_SUFS) # Combine all header suffixes and remove duplicates via sort(). 
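The new rm-dups helper is a recursive make function that removes repeated words while keeping the first occurrence of each, since GNU make's $(sort ...) would deduplicate but also reorder the list. Its effect, written out in C for clarity:

/* Same behaviour as rm-dups: keep the first occurrence of each word,
   drop later repeats, preserve the original order. */
#include <stdio.h>
#include <string.h>

int main( void )
{
    const char* words[] = { "zen4", "zen3", "zen3", "zen2", "zen", "zen2" };
    const int   n       = sizeof(words) / sizeof(words[0]);

    for ( int i = 0; i < n; ++i )
    {
        int seen = 0;
        for ( int j = 0; j < i; ++j )
            if ( strcmp( words[i], words[j] ) == 0 ) { seen = 1; break; }
        if ( !seen ) printf( "%s ", words[i] );   /* prints: zen4 zen3 zen2 zen */
    }
    printf( "\n" );
    return 0;
}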
ALL_HDR_SUFS := $(sort $(FRAME_HDR_SUFS) \ + $(ADDON_HDR_SUFS) \ $(SANDBOX_HDR_SUFS) \ $(AOCLDTL_HDR_SUFS)) ALL_H99_SUFS := $(sort $(FRAME_HDR_SUFS) \ + $(ADDON_HDR_SUFS) \ $(SANDBOX_H99_SUFS) \ $(AOCLDTL_HDR_SUFS)) @@ -366,12 +406,14 @@ SHELL := bash # Construct paths to the four primary directories of source code: # the config directory, general framework code, reference kernel code, -# and optimized kernel code. +# and optimized kernel code. Also process paths for addon and sandbox +# directories. CONFIG_PATH := $(DIST_PATH)/$(CONFIG_DIR) FRAME_PATH := $(DIST_PATH)/$(FRAME_DIR) AOCLDTL_PATH := $(DIST_PATH)/$(AOCLDTL_DIR) REFKERN_PATH := $(DIST_PATH)/$(REFKERN_DIR) KERNELS_PATH := $(DIST_PATH)/$(KERNELS_DIR) +ADDON_PATH := $(DIST_PATH)/$(ADDON_DIR) SANDBOX_PATH := $(DIST_PATH)/$(SANDBOX_DIR) # Construct paths to some optional C++ template headers contributed by AMD. @@ -386,6 +428,7 @@ FRAME_FRAG_PATH := ./obj/$(CONFIG_NAME)/$(FRAME_DIR) AOCLDTL_FRAG_PATH := ./obj/$(CONFIG_NAME)/$(AOCLDTL_DIR) REFKERN_FRAG_PATH := ./obj/$(CONFIG_NAME)/$(REFKERN_DIR) KERNELS_FRAG_PATH := ./obj/$(CONFIG_NAME)/$(KERNELS_DIR) +ADDON_FRAG_PATH := ./obj/$(CONFIG_NAME)/$(ADDON_DIR) SANDBOX_FRAG_PATH := ./obj/$(CONFIG_NAME)/$(SANDBOX_DIR) @@ -863,6 +906,7 @@ MK_KERNELS_SRC := MK_REFKERN_SRC := MK_FRAME_SRC := MK_AOCLDTL_SRC := +MK_ADDON_SRC := MK_SANDBOX_SRC := # -- config -- @@ -914,6 +958,24 @@ PARENT_PATH := $(OBJ_DIR)/$(CONFIG_NAME) -include $(addsuffix /$(FRAGMENT_MK), $(FRAME_FRAG_PATH)) -include $(addsuffix /$(FRAGMENT_MK), $(AOCLDTL_FRAG_PATH)) +# -- addon -- + +# Construct paths to each addon. +# NOTE: If $(ADDON_LIST) is empty (because no addon was enabled at configure- +# time) then $(ADDON_PATHS) will also be empty, which will cause no fragments +# to be included. +ADDON_PATHS := $(addprefix $(ADDON_FRAG_PATH)/, $(ADDON_LIST)) + +# This variable is used by the include statements as they recursively include +# one another. For the 'addons' directory, we initialize it to that directory +# in preparation to include the fragments in the configuration sub-directory. +PARENT_SRC_PATH := $(ADDON_PATH) +PARENT_PATH := $(ADDON_FRAG_PATH) + +# Recursively include the makefile fragments in each of the addons sub- +# directories. +-include $(addsuffix /$(FRAGMENT_MK), $(ADDON_PATHS)) + # -- sandbox -- # Construct paths to each sandbox. (At present, there can be only one.) @@ -931,6 +993,8 @@ PARENT_PATH := $(SANDBOX_FRAG_PATH) # Recursively include the makefile fragments in the sandbox sub-directory. -include $(addsuffix /$(FRAGMENT_MK), $(SANDBOX_PATHS)) +# -- post-processing -- + # Create a list of the makefile fragments using the variable into which each # of the above include statements accumulated their directory paths. MAKEFILE_FRAGMENTS := $(addsuffix /$(FRAGMENT_MK), $(FRAGMENT_DIR_PATHS)) @@ -949,14 +1013,14 @@ endif # # Define a function that will expand all of the directory paths given in $(1) -# to actual filepaths using the list of suffixes provided $(2). +# to actual filepaths using the list of suffixes provided in $(2). get-filepaths = $(strip $(foreach path, $(1), \ $(foreach suf, $(2), \ $(wildcard $(path)/*.$(suf)) \ ) ) ) # Define a function that will expand all of the directory paths given in $(1) -# to actual filepaths using the list of suffixes provided $(2), taking only +# to actual filepaths using the list of suffixes provided in $(2), taking only # the first expansion from each directory with at least one file matching # the current suffix. 
Finally, strip the filenames from all resulting files, # returning only the directory paths. @@ -966,20 +1030,29 @@ get-dirpaths = $(dir $(foreach path, $(1), \ $(wildcard $(path)/*.$(suf)) \ ) ) ) ) -# We'll use two directory lists. The first is a list of all of the directories -# in which makefile fragments were generated (plus the current directory). The -# second is the subset of the first that begins with the sandbox root path. +# We'll use three directory lists. The first is a list of all of the directories +# in which makefile fragments were generated, plus the current directory. (The +# current directory is needed so we include bli_config.h and bli_addon.h in the +# processing of header files.) The second and third are subsets of the first +# that begins with the addon and sandbox root paths, respectively. ALLFRAG_DIR_PATHS := . $(FRAGMENT_DIR_PATHS) +ADDON_DIR_PATHS := $(filter $(ADDON_PATH)/%,$(ALLFRAG_DIR_PATHS)) SANDBOX_DIR_PATHS := $(filter $(SANDBOX_PATH)/%,$(ALLFRAG_DIR_PATHS)) ALL_H99_FILES := $(call get-filepaths,$(ALLFRAG_DIR_PATHS),$(ALL_H99_SUFS)) -FRAME_H99_FILES := $(filter-out $(SANDBOX_PATH)/%,$(ALL_H99_FILES)) +FRAME_H99_FILES := $(filter-out $(ADDON_PATH)/%, \ + $(filter-out $(SANDBOX_PATH)/%, \ + $(ALL_H99_FILES) \ + ) ) -ALL_H99_DIRPATHS := $(call get-dirpaths,$(ALLFRAG_DIR_PATHS),$(ALL_H99_SUFS)) +ALL_H99_DIRPATHS := $(call get-dirpaths,$(ALLFRAG_DIR_PATHS),$(ALL_H99_SUFS)) -SANDBOX_H99_FILES := $(call get-filepaths,$(SANDBOX_DIR_PATHS),$(SANDBOX_H99_SUFS)) -SANDBOX_HXX_FILES := $(call get-filepaths,$(SANDBOX_DIR_PATHS),$(SANDBOX_HXX_SUFS)) +ADDON_H99_FILES := $(call get-filepaths,$(ADDON_DIR_PATHS),$(ADDON_H99_SUFS)) +ADDON_HXX_FILES := $(call get-filepaths,$(ADDON_DIR_PATHS),$(ADDON_HXX_SUFS)) +ADDON_HDR_DIRPATHS := $(call get-dirpaths,$(ADDON_DIR_PATHS),$(ALL_HDR_SUFS)) +SANDBOX_H99_FILES := $(call get-filepaths,$(SANDBOX_DIR_PATHS),$(SANDBOX_H99_SUFS)) +SANDBOX_HXX_FILES := $(call get-filepaths,$(SANDBOX_DIR_PATHS),$(SANDBOX_HXX_SUFS)) SANDBOX_HDR_DIRPATHS := $(call get-dirpaths,$(SANDBOX_DIR_PATHS),$(ALL_HDR_SUFS)) @@ -1021,6 +1094,7 @@ BLIS_H_FLAT := $(BASE_INC_PATH)/$(BLIS_H) # header files. CBLAS_H := cblas.h CBLAS_H_SRC_PATH := $(filter %/$(CBLAS_H), $(FRAME_H99_FILES)) +CBLAS_H_DIRPATH := $(dir $(CBLAS_H_SRC_PATH)) # Construct the path to what will be the intermediate flattened/monolithic # cblas.h file. @@ -1032,8 +1106,8 @@ CBLAS_H_FLAT := $(BASE_INC_PATH)/$(CBLAS_H) # # Obtain a list of header files #included inside of the bli_cntx_ref.c file. -# Paths to these files will be needed when compiling with the monolithic -# header. +# Due to the way that bli_cntx_ref.c uses headers and macros, paths to these +# files will be needed when compiling bli_cntx_ref.c with the monolithic header. ifeq ($(strip $(SHARE_PATH)),.) REF_KER_SRC := $(DIST_PATH)/$(REFKERN_DIR)/bli_cntx_ref.c REF_KER_HEADERS := $(shell $(GREP) "\#include" $(REF_KER_SRC) | sed -e "s/\#include [\"<]\([a-zA-Z0-9\_\.\/\-]*\)[\">].*/\1/g" | $(GREP) -v $(BLIS_H)) @@ -1041,9 +1115,10 @@ endif # Match each header found above with the path to that header, and then strip # leading, trailing, and internal whitespace. -REF_KER_H_PATHS := $(strip $(foreach header, $(REF_KER_HEADERS), \ - $(dir $(filter %/$(header), \ - $(FRAME_H99_FILES))))) +REF_KER_H_PATHS := $(call rm-dups,$(strip \ + $(foreach header, $(REF_KER_HEADERS), \ + $(dir $(filter %/$(header), \ + $(FRAME_H99_FILES)))))) # Add -I to each header path so we can specify our include search paths to the # C compiler. 
Then add frame/include since it's needed for bli_oapi_w[o]_cntx.h. @@ -1055,17 +1130,29 @@ REF_KER_I_PATHS += -I$(DIST_PATH)/frame/include # now #include the monolithic/flattened blis.h instead. CINCFLAGS := -I$(BASE_INC_PATH) $(REF_KER_I_PATHS) +# If CBLAS is enabled, we also include the path to the cblas.h directory so +# that the compiler will be able to find cblas.h as the CBLAS source code is +# being compiled. +ifeq ($(MK_ENABLE_CBLAS),yes) +CINCFLAGS += -I$(CBLAS_H_DIRPATH) +endif + +# Obtain a list of header paths in the configured addons. Then add -I to each +# header path. +CADDONINCFLAGS := $(strip $(patsubst %, -I%, $(ADDON_HDR_DIRPATHS))) + # Obtain a list of header paths in the configured sandbox. Then add -I to each # header path. -CSBOXINCFLAGS := $(strip $(patsubst %, -I%, $(SANDBOX_HDR_DIRPATHS))) +CSANDINCFLAGS := $(strip $(patsubst %, -I%, $(SANDBOX_HDR_DIRPATHS))) # # --- BLIS configuration header definitions ------------------------------------ # -# This file was created by configure, but we need to define it here so we can -# remove it as part of the clean targets. +# These files were created by configure, but we need to define them here so we +# can remove them as part of the clean targets. +BLIS_ADDON_H := ./bli_addon.h BLIS_CONFIG_H := ./bli_config.h diff --git a/config/CMakeLists.txt b/config/CMakeLists.txt index 12568f67f7..7429ff42ee 100644 --- a/config/CMakeLists.txt +++ b/config/CMakeLists.txt @@ -1,6 +1,9 @@ -##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. ## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. ## -if(${TARGET_ARCH} STREQUAL zen3) +if(${TARGET_ARCH} STREQUAL zen4) +message("The configuration is : ${TARGET_ARCH}") +add_subdirectory(zen4) +elseif(${TARGET_ARCH} STREQUAL zen3) message("The configuration is : ${TARGET_ARCH}") add_subdirectory(zen3) elseif(${TARGET_ARCH} STREQUAL zen2) @@ -15,6 +18,7 @@ add_subdirectory(generic) add_subdirectory(zen) add_subdirectory(zen2) add_subdirectory(zen3) +add_subdirectory(zen4) elseif(${TARGET_ARCH} STREQUAL haswell) message("The configuration is : ${TARGET_ARCH}") add_subdirectory(haswell) diff --git a/config/amdzen/bli_family_amdzen.h b/config/amdzen/bli_family_amdzen.h index c73409673d..0cf46d5a4e 100644 --- a/config/amdzen/bli_family_amdzen.h +++ b/config/amdzen/bli_family_amdzen.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -59,7 +59,29 @@ // BLIS), defining this macro as 1 yields better performance. #define AOCL_BLIS_MULTIINSTANCE 0 -//#define BLIS_ENABLE_FAST_MATH +/* + * Override the block sizes in the context to the block sizes used + * by AVX2 GEMM+TRSM kernels, this is needed in Zen4 context as default + * GEMM kernels are AVX512 based and uses different block sizes. + * + * This function should be called in TRSM path before performing + * any packing operations. + * + * Also the context must be restored to default values by calling + * bli_zen4_restore_default_blkszs() before exiting TRSM Path + */ +BLIS_EXPORT_BLIS void bli_zen4_override_trsm_blkszs (cntx_t* cntx); + +/* + * Restore the block sizes to default values needed for zen4 context. 
+ * + * This function should be called to restore the block sizes to there + * default values if they where overriden by calling + * bli_zen4_override_trsm_blkszs() to enable AVX2 GEMM kernels in the + * TRSM path. + * + */ +BLIS_EXPORT_BLIS void bli_zen4_restore_default_blkszs (cntx_t* cntx); #endif diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index 3fea3ea8f9..9d4197712e 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -104,11 +104,11 @@ void bli_cntx_init_zen( cntx_t* cntx ) bli_cntx_set_l1v_kers ( 26, -#if 1 + // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, -#endif + // axpbyv BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, @@ -116,16 +116,11 @@ void bli_cntx_init_zen( cntx_t* cntx ) BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, // axpyv -#if 0 - BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int, - BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int, -#else BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, BLIS_AXPYV_KER, BLIS_SCOMPLEX, bli_caxpyv_zen_int5, BLIS_AXPYV_KER, BLIS_DCOMPLEX, bli_zaxpyv_zen_int5, -#endif // dotv BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int, BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int, @@ -138,18 +133,18 @@ void bli_cntx_init_zen( cntx_t* cntx ) BLIS_DOTXV_KER, BLIS_DCOMPLEX, bli_zdotxv_zen_int, BLIS_DOTXV_KER, BLIS_SCOMPLEX, bli_cdotxv_zen_int, // scalv -#if 0 - BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int, - BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int, -#else + BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, -#endif + + // swapv BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, + // copyv BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, + //set BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, @@ -231,9 +226,9 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Initialize sup thresholds with architecture-appropriate values. // s d c z - bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 380, 110 ); + bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 380, 128 ); bli_blksz_init_easy( &thresh[ BLIS_NT ], 512, 256, 256, 128 ); - bli_blksz_init_easy( &thresh[ BLIS_KT ], 440, 220, 220, 110 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], 440, 220, 220, 128 ); // Initialize the context with the sup thresholds. bli_cntx_set_l3_sup_thresh diff --git a/config/zen/bli_family_zen.h b/config/zen/bli_family_zen.h index 23d3d608c7..8b31c32ca0 100644 --- a/config/zen/bli_family_zen.h +++ b/config/zen/bli_family_zen.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
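The bli_zen4_override_trsm_blkszs()/bli_zen4_restore_default_blkszs() pair declared above is meant to bracket the TRSM path: the zen4 context defaults to AVX-512 GEMM block sizes, while TRSM still runs AVX2 gemmtrsm kernels that need different blocking. A sketch of the intended call pattern (the real call sites are inside the BLIS TRSM implementation and are not shown in this patch):

#include "blis.h"

/* Illustrative only: bracket the AVX2 GEMM+TRSM work with the override and
   restore calls, as the comments above prescribe. */
static void trsm_path_sketch( cntx_t* cntx )
{
    /* Switch the context to the AVX2-sized MC/KC/NC/MR/NR before packing. */
    bli_zen4_override_trsm_blkszs( cntx );

    /* ... pack A and B and run the AVX2 gemmtrsm/gemm micro-kernels ... */

    /* Put the AVX-512 defaults back before leaving the TRSM path. */
    bli_zen4_restore_default_blkszs( cntx );
}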
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -53,6 +53,4 @@ #define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 #define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 -//#define BLIS_ENABLE_FAST_MATH - #endif diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 1ecb62ff52..3ce2fced92 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -116,11 +116,11 @@ void bli_cntx_init_zen2( cntx_t* cntx ) bli_cntx_set_l1v_kers ( 26, -#if 1 + // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, -#endif + // axpbyv BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, diff --git a/config/zen2/bli_family_zen2.h b/config/zen2/bli_family_zen2.h index dbae9752cc..16fe50609e 100644 --- a/config/zen2/bli_family_zen2.h +++ b/config/zen2/bli_family_zen2.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -56,6 +56,5 @@ // When running HPL with pure MPI without DGEMM threading (Single-threaded // BLIS), defining this macro as 1 yields better performance. #define AOCL_BLIS_MULTIINSTANCE 0 -//#define BLIS_ENABLE_FAST_MATH #endif diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index 02e264d277..779bb7277c 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -116,11 +116,11 @@ void bli_cntx_init_zen3( cntx_t* cntx ) bli_cntx_set_l1v_kers ( 26, -#if 1 + // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, -#endif + // axpbyv BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, diff --git a/config/zen3/bli_family_zen3.h b/config/zen3/bli_family_zen3.h index 69def1422d..ce84104c52 100644 --- a/config/zen3/bli_family_zen3.h +++ b/config/zen3/bli_family_zen3.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -55,6 +55,4 @@ #define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 #define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 -//#define BLIS_ENABLE_FAST_MATH - #endif diff --git a/config/zen4/CMakeLists.txt b/config/zen4/CMakeLists.txt new file mode 100644 index 0000000000..ea166b00c7 --- /dev/null +++ b/config/zen4/CMakeLists.txt @@ -0,0 +1,7 @@ +##Copyright (C) 2022, Advanced Micro Devices, Inc ## + +target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_cntx_init_zen4.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_family_zen4.h + ) diff --git a/config/zen4/bli_cntx_init_zen4.c b/config/zen4/bli_cntx_init_zen4.c new file mode 100644 index 0000000000..ac9875abf6 --- /dev/null +++ b/config/zen4/bli_cntx_init_zen4.c @@ -0,0 +1,358 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
+ + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +/* + * List of default block sizes for zen4. + * Converted it to macro as this list is used at multiple places in this file. + */ + +#define BLI_CNTX_DEFAULT_BLKSZ_LIST(blkszs) \ + /* s d c z */ \ + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 16, 3, 3 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 14, 8, 4 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 512, 240, 144, 18 ); \ + bli_blksz_init ( &blkszs[ BLIS_KC ], 480, 512, 256, 566, \ + 480, 320, 256, 566 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 6144, 4004, 4080, 256 ); \ + \ + bli_blksz_init_easy( &blkszs[ BLIS_AF ], 8, 8, -1, -1 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); \ + + +void bli_cntx_init_zen4( cntx_t* cntx ) +{ + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + blksz_t thresh[ BLIS_NUM_THRESH ]; + // Set default kernel blocksizes and functions. + bli_cntx_init_zen4_ref( cntx ); + + // ------------------------------------------------------------------------- + + // Update the context with optimized native gemm micro-kernels and + // their storage preferences. 
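The BLI_CNTX_DEFAULT_BLKSZ_LIST macro above mixes two initializers: bli_blksz_init_easy() registers one value per datatype, used as both the default and the maximum blocksize, while bli_blksz_init() takes a second set of per-datatype values for the maximums (as done for KC), which BLIS consults when sizing internal buffers. A compact reading of the two calls, assuming blis.h:

#include "blis.h"

void blksz_demo( void )
{
    blksz_t kc, nr;

    /* bli_blksz_init(): defaults for s,d,c,z followed by their maximums.  */
    /*                        s    d    c    z      s    d    c    z       */
    bli_blksz_init     ( &kc, 480, 512, 256, 566, 480, 320, 256, 566 );

    /* bli_blksz_init_easy(): one value per datatype, default == maximum.  */
    bli_blksz_init_easy( &nr,  12,  14,   8,   4 );
}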
+ bli_cntx_set_l3_nat_ukrs + ( + 10, + // gemm + BLIS_GEMM_UKR, BLIS_FLOAT , bli_sgemm_skx_asm_32x12_l2, FALSE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_skx_asm_16x14, FALSE, + BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, + BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_haswell_asm_3x4, TRUE, + + BLIS_GEMM_AVX2_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, + BLIS_GEMM_AVX2_UKR, BLIS_DOUBLE, bli_dgemm_haswell_asm_6x8, TRUE, + + // gemmtrsm_l + BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_zen_asm_16x14, TRUE, + // gemmtrsm_u + BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_zen_asm_16x14, TRUE, + + cntx + ); + + // Update the context with architecture specific threshold functions + bli_cntx_set_l3_thresh_funcs + ( + 2, + // GEMMT + BLIS_GEMMT, bli_cntx_gemmtsup_thresh_is_met_zen, + // SYRK + BLIS_SYRK, bli_cntx_syrksup_thresh_is_met_zen, + cntx + ); + + // packm kernels + bli_cntx_set_packm_kers + ( + 8, + BLIS_PACKM_6XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_6xk, + BLIS_PACKM_16XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_16xk, + BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_6xk, + BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_8xk, + BLIS_PACKM_3XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_3xk, + BLIS_PACKM_8XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_8xk, + BLIS_PACKM_3XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_3xk, + BLIS_PACKM_4XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_4xk, + cntx + ); + + // Update the context with optimized level-1f kernels. + bli_cntx_set_l1f_kers + ( + 9, + // axpyf + BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + // dotxf + BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DCOMPLEX, bli_zdotxf_zen_int_6, + BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, + // axpy2v + BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, + cntx + ); + + // Update the context with optimized level-1v kernels. 
+ bli_cntx_set_l1v_kers + ( + 24, + + // amaxv + BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int_avx512, + BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, + + // axpbyv + BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_SCOMPLEX, bli_caxpbyv_zen_int, + BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, + + // axpyv + BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, + BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, + BLIS_AXPYV_KER, BLIS_SCOMPLEX, bli_caxpyv_zen_int5, + BLIS_AXPYV_KER, BLIS_DCOMPLEX, bli_zaxpyv_zen_int5, + + // dotv + BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int10, + BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int10, + BLIS_DOTV_KER, BLIS_SCOMPLEX, bli_cdotv_zen_int5, + BLIS_DOTV_KER, BLIS_DCOMPLEX, bli_zdotv_zen_int5, + + // dotxv + BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + + // scalv + BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, + BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, + + //swap + BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, + BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, + + //copy + BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, + BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, + + //set + BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, + BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, + cntx + ); + + // Initialize level-3 blocksize objects with architecture-specific values. + // + // These are reference block sizes and may be overridden based on + // number of threads used at runtime. + + BLI_CNTX_DEFAULT_BLKSZ_LIST(blkszs); + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_blkszs + ( + BLIS_NAT, 7, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + // level-1f + BLIS_AF, &blkszs[ BLIS_AF ], BLIS_AF, + BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, + cntx + ); + // ------------------------------------------------------------------------- + + // Initialize sup thresholds with architecture-appropriate values. s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 380, 110 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], 200, 256, 256, 128 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], 240, 220, 220, 110 ); + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + + // Initialize the context with the sup handlers. + bli_cntx_set_l3_sup_handlers + ( + 2, + BLIS_GEMM, bli_gemmsup_ref, + BLIS_GEMMT, bli_gemmtsup_ref, + cntx + ); + + // Update the context with optimized small/unpacked gemm kernels. 
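The BLIS_MT/BLIS_NT/BLIS_KT values registered just above act as per-datatype gates for the small/unpacked (sup) path. The sketch below captures the gating logic as commonly described for the sup framework, namely that the sup path is attempted when at least one problem dimension falls below its threshold; the exact comparison lives in the BLIS framework rather than in this hunk, so treat it as an assumption.

#include <stdbool.h>

/* m, n, k: problem dimensions; mt, nt, kt: the per-datatype thresholds above. */
static bool sup_thresh_is_met_sketch( long m, long n, long k,
                                      long mt, long nt, long kt )
{
    if ( m < mt ) return true;   /* skinny in m */
    if ( n < nt ) return true;   /* skinny in n */
    if ( k < kt ) return true;   /* skinny in k */
    return false;                /* large in every dimension: packed native path */
}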
+ bli_cntx_set_l3_sup_kers + ( + 28, + //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16m, TRUE, + BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16n, TRUE, + BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, 3, 3, + 9, 9, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 256, 128, 64 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 2040, 1020 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); +} + +/* + * Override the block sizes in the context to the block sizes used + * by AVX2 GEMM+TRSM kernels, this is needed in Zen4 context as default + * GEMM kernels are AVX512 based and uses different block sizes. + * + * This function should be called in TRSM path before performing + * any packing operations. 
+ * + * Also the context must be restored to default values by calling + * bli_zen4_restore_default_blkszs() before exiting TRSM Path + */ +void bli_zen4_override_trsm_blkszs (cntx_t* cntx) +{ + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 16, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 14, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 240, 144, 72 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 512, 256, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4004, 4080, 4080 ); + + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_blkszs + ( + BLIS_NAT, 5, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + cntx + ); +} + +/* + * Restore the block sizes to default values needed for zen4 context. + * + * This function should be called to restore the block sizes to there + * default values if they where overriden by calling + * bli_zen4_override_trsm_blkszs() to enable AVX2 GEMM kernels in the + * TRSM path. + * + */ +void bli_zen4_restore_default_blkszs (cntx_t* cntx) +{ + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + + BLI_CNTX_DEFAULT_BLKSZ_LIST(blkszs); + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_blkszs + ( + BLIS_NAT, 7, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + // level-1f + BLIS_AF, &blkszs[ BLIS_AF ], BLIS_AF, + BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, + cntx + ); +} diff --git a/config/zen4/bli_family_zen4.h b/config/zen4/bli_family_zen4.h new file mode 100644 index 0000000000..b21d1582f7 --- /dev/null +++ b/config/zen4/bli_family_zen4.h @@ -0,0 +1,87 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLI_FAMILY_ZEN4_ +#define BLI_FAMILY_ZEN4_ + +// By default, it is effective to parallelize the outer loops. +// Setting these macros to 1 will force JR and IR inner loops +// to be not paralleized. +// +#define BLIS_THREAD_MAX_IR 1 +#define BLIS_THREAD_MAX_JR 1 + +#define BLIS_ENABLE_SMALL_MATRIX +#define BLIS_ENABLE_SMALL_MATRIX_TRSM + +// This will select the threshold below which small matrix code will be called. +#define BLIS_SMALL_MATRIX_THRES 700 +#define BLIS_SMALL_M_RECT_MATRIX_THRES 160 +#define BLIS_SMALL_K_RECT_MATRIX_THRES 128 + +#define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 +#define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 + +// -- SIMD config -------------------------------------------------------- + +#define BLIS_SIMD_ALIGN_SIZE 64 + +#define BLIS_SIMD_SIZE 64 +#define BLIS_SIMD_NUM_REGISTERS 32 + +/* + * Override the block sizes in the context to the block sizes used + * by AVX2 GEMM+TRSM kernels, this is needed in Zen4 context as default + * GEMM kernels are AVX512 based and uses different block sizes. + * + * This function should be called in TRSM path before performing + * any packing operations. + * + * Also the context must be restored to default values by calling + * bli_zen4_restore_default_blkszs() before exiting TRSM Path + */ +BLIS_EXPORT_BLIS void bli_zen4_override_trsm_blkszs (cntx_t* cntx); + +/* + * Restore the block sizes to default values needed for zen4 context. + * + * This function should be called to restore the block sizes to there + * default values if they where overriden by calling + * bli_zen4_override_trsm_blkszs() to enable AVX2 GEMM kernels in the + * TRSM path. + * + */ +BLIS_EXPORT_BLIS void bli_zen4_restore_default_blkszs (cntx_t* cntx); + +#endif diff --git a/config/zen4/make_defs.mk b/config/zen4/make_defs.mk new file mode 100644 index 0000000000..062e680910 --- /dev/null +++ b/config/zen4/make_defs.mk @@ -0,0 +1,161 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + +# FLAGS that are specific to the 'zen4' architecture are added here. +# FLAGS that are common for all the AMD architectures are present in +# config/zen/amd_config.mk. + +# Declare the name of the current configuration and add it to the +# running list of configurations included by common.mk. +THIS_CONFIG := zen4 +#CONFIGS_INCL += $(THIS_CONFIG) + +# +# --- Determine the C compiler and related flags --- +# + +# NOTE: The build system will append these variables with various +# general-purpose/configuration-agnostic flags in common.mk. You +# may specify additional flags here as needed. + +CPPROCFLAGS := +CMISCFLAGS := +CPICFLAGS := +CWARNFLAGS := + +ifneq ($(DEBUG_TYPE),off) +CDBGFLAGS := -g +endif + +ifeq ($(DEBUG_TYPE),noopt) +COPTFLAGS := -O0 +else +COPTFLAGS := -O3 +endif + +# Flags specific to optimized kernels. +# NOTE: The -fomit-frame-pointer option is needed for some kernels because +# they make explicit use of the rbp register. +CKOPTFLAGS := $(COPTFLAGS) -fomit-frame-pointer +ifeq ($(CC_VENDOR),gcc) +GCC_VERSION := $(strip $(shell $(CC) -dumpversion | cut -d. -f1)) + + +# gcc 11.0 or later: +ifeq ($(shell test $(GCC_VERSION) -ge 11; echo $$?),0) +# Update CKOPTFLAGS for gcc 11+ to use O3 optimization without +# -ftree-partial-pre flag. This flag results in suboptimal code +# generation for instrinsics based kernels. +ifneq ($(DEBUG_TYPE),noopt) +CKOPTFLAGS := -O2 -fgcse-after-reload -fipa-cp-clone -floop-interchange -floop-unroll-and-jam -fpeel-loops -fpredictive-commoning -fsplit-loops -fsplit-paths -ftree-loop-distribution -funswitch-loops -fvect-cost-model=dynamic -fversion-loops-for-strides -fomit-frame-pointer +endif + +CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mfpmath=sse +CRVECFLAGS += -march=znver3 +else +# gcc 9.0 or later: +ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) +CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mfpmath=sse +CRVECFLAGS += -march=znver2 +else +ifeq ($(shell test $(GCC_VERSION) -ge 8; echo $$?),0) +CKVECFLAGS += -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mfpmath=sse +CRVECFLAGS += -march=znver1 +else +# If gcc is older than 8.0.0 but at least 6.1.0, then we can use -march=znver1 +# as the fallback option. 
+CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store +CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store +endif # GCC 8 +endif # GCC 9 +endif # GCC 11 +else +ifeq ($(CC_VENDOR),clang) + +# AOCC clang has various formats for the version line + +# AOCC.LLVM.2.0.0.B191.2019_07_19 clang version 8.0.0 (CLANG: Jenkins AOCC_2_0_0-Build#191) (based on LLVM AOCC.LLVM.2.0.0.B191.2019_07_19) +# AOCC.LLVM.2.1.0.B1030.2019_11_12 clang version 9.0.0 (CLANG: Build#1030) (based on LLVM AOCC.LLVM.2.1.0.B1030.2019_11_12) +# AMD clang version 10.0.0 (CLANG: AOCC_2.2.0-Build#93 2020_06_25) (based on LLVM Mirror.Version.10.0.0) +# AMD clang version 11.0.0 (CLANG: AOCC_2.3.0-Build#85 2020_11_10) (based on LLVM Mirror.Version.11.0.0) +# AMD clang version 12.0.0 (CLANG: AOCC_3.0.0-Build#2 2020_11_05) (based on LLVM Mirror.Version.12.0.0) +# AMD clang version 14.0.0 (CLANG: AOCC_4.0.0-Build#98 2022_06_15) (based on LLVM Mirror.Version.14.0.0) + +# For our prupose we just want to know if it version 2x or 3x or 4x + +# for version 4x we will enable znver4 +ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_4')),1) +CKVECFLAGS += -march=znver4 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512bf16 -mfpmath=sse +CRVECFLAGS += -march=znver4 +else +# for version 3x we will enable znver3 +ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_3')),1) +CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mfpmath=sse +CRVECFLAGS += -march=znver3 +else +# for version 2x we will enable znver2 +ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC.LLVM.2\|AOCC_2')),1) +CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mfpmath=sse +CRVECFLAGS += -march=znver2 +else +#if compiling with clang +VENDOR_STRING := $(strip $(shell ${CC_VENDOR} --version | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*')) +CC_MAJOR := $(shell (echo ${VENDOR_STRING} | cut -d. -f1)) +#clang 9.0 or later: +ifeq ($(shell test $(CC_MAJOR) -ge 9; echo $$?),0) +CKVECFLAGS += -march=znver2 +CRVECFLAGS += -march=znver2 +else +CKVECFLAGS += -march=znver1 +CRVECFLAGS += -march=znver1 +endif # ge 9 +endif # aocc 2 +endif # aocc 3 +endif # aocc 4 +endif # clang +endif # gcc + +# Flags specific to reference kernels. +CROPTFLAGS := $(CKOPTFLAGS) + +# Flags specific to reference kernels. +# Note: We use AVX2 for reference kernels because, as Jeff Hammond says, +# reference kernel code "is not going to achieve high enough SIMD utilization +# to overcome the AVX-512 frequency drop". (Issue #187) +CRVECFLAGS += -mno-avx512f -mno-avx512vl -mno-avx512bw -mno-avx512dq -mno-avx512cd -funsafe-math-optimizations -ffp-contract=fast + +# Store all of the variables here to new variables containing the +# configuration name. +$(eval $(call store-make-defs,$(THIS_CONFIG))) + diff --git a/config_registry b/config_registry index 558eccc30c..4e6716dfa1 100644 --- a/config_registry +++ b/config_registry @@ -11,7 +11,7 @@ x86_64: intel64 amd64 amd64_legacy intel64: skx knl haswell sandybridge penryn generic amd64_legacy: excavator steamroller piledriver bulldozer generic -amdzen: zen3 zen2 zen generic +amdzen: zen4 zen3 zen2 zen generic # NOTE: ARM families will remain disabled until runtime hardware detection # logic is added to BLIS. @@ -26,6 +26,7 @@ sandybridge: sandybridge penryn: penryn # AMD architectures. 
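The CKVECFLAGS additions above enable the AVX-512 instruction subsets that the zen4 and skx kernels require, while CRVECFLAGS deliberately masks AVX-512 back off for reference kernels (per the comment above, they would not reach enough SIMD utilization to offset the frequency drop). The toy program below only shows why the kernel flags are needed: without -mavx512f the compiler rejects 512-bit intrinsics like these. Compiling it requires -mavx512f; running it requires an AVX-512 capable CPU.

#include <immintrin.h>
#include <stdio.h>

int main( void )
{
    double a[8] = { 1, 2, 3, 4, 5, 6, 7, 8 }, b[8], s = 2.0;

    __m512d va = _mm512_loadu_pd( a );            /* 8 doubles per register */
    __m512d vs = _mm512_set1_pd( s );
    _mm512_storeu_pd( b, _mm512_mul_pd( va, vs ) );

    printf( "%f %f\n", b[0], b[7] );              /* prints 2.0 and 16.0 */
    return 0;
}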
+zen4: zen4/zen4/skx/zen3/zen2/zen/haswell zen3: zen3/zen3/zen2/zen/haswell zen2: zen2/zen2/zen/haswell zen: zen/zen/haswell diff --git a/configure b/configure index f49ea19e5e..73dc8cc358 100755 --- a/configure +++ b/configure @@ -264,6 +264,15 @@ print_usage() echo " \"small\" depends on thresholds that may vary by sub-" echo " configuration." echo " " + echo " -a NAME --enable-addon=NAME" + echo " " + echo " Enable the code provided by an addon. An addon consists" + echo " of a separate directory of code that provides additional" + echo " APIs, implementations, and/or operations that would" + echo " otherwise not be present within a build of BLIS. This" + echo " option may be used multiple times to specify the inclusion" + echo " of multiple addons. By default, no addons are enabled." + echo " " echo " -s NAME --enable-sandbox=NAME" echo " " echo " Enable a separate sandbox implementation of gemm. This" @@ -344,6 +353,19 @@ print_usage() echo " Num_threads is derived from either environment variable" echo " OMP_NUM_THREADS or BLIS_NUM_THREADS' or bli_set_num_threads() API." echo " " + echo " --enable-blis-arch-type, --disable-blis-arch-type" + echo " " + echo " Disable (Enabled by default) support for BLIS_ARCH_TYPE" + echo " environment variable, which allows user to select" + echo " architecture-specific code path at runtime." + echo " If disabled, in builds with multiple code paths, BLIS" + echo " will still select path automatically." + echo " " + echo " --rename-blis-arch-type=STRING" + echo " " + echo " Change environment variable used to select architecture-specific" + echo " code path from BLIS_ARCH_TYPE to STRING" + echo " " echo " -q, --quiet Suppress informational output. By default, configure" echo " is verbose. (NOTE: -q is not yet implemented)" echo " " @@ -940,6 +962,18 @@ canonicalize_ws() echo "${str}" } +rm_duplicate_words_simple() +{ + local str revstr revres res + + str="$1" + + # Remote duplicates, keeping the first occurrence. + res=$(echo "${str}" | awk '{for (i=1;i<=NF;i++) if (!a[$i]++) printf("%s%s",$i,FS)}{printf("\n")}') + + echo "${res}" +} + rm_duplicate_words() { local str revstr revres res @@ -1124,8 +1158,11 @@ auto_detect() # NOTE: -D_GNU_SOURCE is needed to enable POSIX extensions to # pthreads (i.e., barriers). + double_quote_open=\"\\\" + double_quote_close=\\\"\" cmd="${cc} ${config_defines} \ -DBLIS_CONFIGURETIME_CPUID \ + -D__blis_arch_type_name=${double_quote_open}${rename_blis_arch_type}${double_quote_close} \ ${c_hdr_paths} \ -std=c99 -D_GNU_SOURCE \ ${cflags} \ @@ -1798,67 +1835,13 @@ try_assemble() set_default_version() { - local gitdir version_file gd_stderr git_describe_str git_error new_version_str - - gitdir='.git' - # The path to the version file. version_file=$1 echo "${script_name}: determining default version string." - # Check if the .git dir exists; if it does not, we do nothing. - if [ -d "${dist_path}/${gitdir}" ]; then - - echo "${script_name}: found '${gitdir}' directory; assuming git clone." - - echo "${script_name}: executing: git describe --tags." - - gd_stderr="git_describe_stderr.txt" - - # Query git for the version string, which is simply the current tag, - # followed by a number signifying how many commits have transpired - # since the tag, followed by a 'g' and a shortened hash tab. Capture - # stderr to a file. - git_describe_str=$(git -C ${dist_path} describe --tags 2> ${gd_stderr}) - - # Pull in whatever error message was generated, if any, and delete - # the file. 
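The `-D__blis_arch_type_name=...` definition added to the configure-time CPUID probe above turns the (possibly renamed) environment-variable name into a C string literal. The snippet below is only an illustrative sketch of how such a macro can be consumed at run time, not BLIS source; the `getenv()` lookup and the fallback default are assumptions made for the example.

```c
#include <stdio.h>
#include <stdlib.h>

/* configure passes -D__blis_arch_type_name="\"FOO\"" when invoked with
   --rename-blis-arch-type=FOO; fall back to the default name otherwise. */
#ifndef __blis_arch_type_name
#define __blis_arch_type_name "BLIS_ARCH_TYPE"
#endif

int main( void )
{
    /* Hypothetical consumer: read the code-path override, if any. */
    const char* req = getenv( __blis_arch_type_name );

    if ( req ) printf( "requested code path: %s\n", req );
    else       printf( "no override; the code path is selected automatically\n" );

    return 0;
}
```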
- git_error=$(cat ${gd_stderr}) - - # Remove the stderr file. - rm -f ${gd_stderr} - - # If git returned an error, don't do anything. - if [ -n "${git_error}" ]; then - - echo "${script_name}: git returned an error: '${git_error}'." - echo "${script_name}: using string from unmodified version file." - - # Use what's in the version file as-is. - version="AOCL BLIS $(cat "${version_file}")" - else - - echo "${script_name}: got back ${git_describe_str}." - - # Strip off the commit hash label. - new_version_str=$(echo ${git_describe_str} | cut -d- -f2) - - echo "${script_name}: truncating to ${new_version_str}." - - # Write the new version string to the version file. - #echo "${new_version_str}" > ${version_file} - - # Set the version variable. - version="AOCL BLIS ${new_version_str}" - fi - else - - echo "${script_name}: could not find '${gitdir}' directory; using unmodified version file." - - # Use what's in the version file as-is. - version="AOCL BLIS $(cat "${version_file}")" - fi + # Use what's in the version file as-is. + version="AOCL-BLIS $(cat "${version_file}") Build $(date +%Y%m%d)" } @@ -1915,6 +1898,13 @@ main() bli_config_h_in_path="${build_dirpath}/${bli_config_h_in}" bli_config_h_out_path="${cur_dirpath}/${bli_config_h_out}" + # The names/paths for the template bli_addon.h.in and its instantiated + # counterpart. + bli_addon_h_in='bli_addon.h.in' + bli_addon_h_out='bli_addon.h' + bli_addon_h_in_path="${build_dirpath}/${bli_addon_h_in}" + bli_addon_h_out_path="${cur_dirpath}/${bli_addon_h_out}" + # Path to 'mirror-tree.sh' script. mirror_tree_sh="${build_dirpath}/mirror-tree.sh" @@ -1941,6 +1931,9 @@ main() # The root directory of the BLIS framework. aocldtl_dir='aocl_dtl' aocldtl_dirpath="${dist_path}/${aocldtl_dir}" + # The names of the addons. + addon_dir='addon' + addon_dirpath="${dist_path}/${addon_dir}" # The name of the sandbox directory. sandbox_dir='sandbox' @@ -2048,6 +2041,12 @@ main() enable_aocl_dynamic='yes' force_version='no' complex_return='default' + disable_blis_arch_type='no' + rename_blis_arch_type='BLIS_ARCH_TYPE' + + # The addon flag and names. + addon_flag='' + addon_list='' # The sandbox flag and name. sandbox_flag='' @@ -2093,7 +2092,7 @@ main() # Process our command line options. unset OPTIND - while getopts ":hp:d:e:s:t:r:qci:b:-:" opt; do + while getopts ":hp:d:e:a:s:t:r:qci:b:-:" opt; do case $opt in -) case "$OPTARG" in @@ -2194,12 +2193,21 @@ main() disable-mem-tracing) enable_mem_tracing='no' ;; + enable-addon=*) + addon_flag=1 + addon_name=${OPTARG#*=} + # Append the addon name to the list. + addon_list="${addon_list} ${addon_name}" + ;; + disable-addon) + addon_flag='' + ;; enable-sandbox=*) sandbox_flag=1 sandbox=${OPTARG#*=} ;; disable-sandbox) - sandbox_flag=0 + sandbox_flag='' ;; int-size=*) int_type_size=${OPTARG#*=} @@ -2264,6 +2272,15 @@ main() complex-return=*) complex_return=${OPTARG#*=} ;; + enable-blis-arch-type) + disable_blis_arch_type='no' + ;; + disable-blis-arch-type) + disable_blis_arch_type='yes' + ;; + rename-blis-arch-type=*) + rename_blis_arch_type=${OPTARG#*=} + ;; *) print_usage ;; @@ -2282,6 +2299,12 @@ main() e) export_shared=$OPTARG ;; + a) + addon_flag=1 + addon_name=$OPTARG + # Append the addon name to the list. + addon_list="${addon_list} ${addon_name}" + ;; s) sandbox_flag=1 sandbox=$OPTARG @@ -3141,6 +3164,34 @@ main() exit 1 fi + # Check if addons were given. + if [ -n "${addon_flag}" ]; then + + # Remove duplicates in the addon list, if they exist. 
+ addon_list=$(rm_duplicate_words_simple "${addon_list}") + + echo "${script_name}: configuring with addons:" + + for addon in ${addon_list}; do + + echo "${script_name}: ${addon_dir}/${addon}" + + addon_fullpath="${addon_dirpath}/${addon}" + + if [ ! -d "${addon_fullpath}" ]; then + echo "${script_name}: requested addon sub-directory does not exist! Cannot continue." + echo "${script_name}: *** Please verify addon existence and name." + exit 1 + fi + done + + enable_addons_01=1 + else + echo "${script_name}: configuring with no addons." + + enable_addons_01=0 + fi + # Check if a sandbox was given. if [ -n "${sandbox_flag}" ]; then @@ -3205,6 +3256,17 @@ main() exit 1 fi + if [ "x${disable_blis_arch_type}" = "xyes" ]; then + echo "${script_name}: user selection of code path using BLIS_ARCH_TYPE env var is disabled." + disable_blis_arch_type_01='1' + else + disable_blis_arch_type_01='0' + fi + + # Check if the user requested a custom env var name to replace BLIS_ARCH_TYPE. + if [ "x${rename_blis_arch_type}" != "xBLIS_ARCH_TYPE" ]; then + echo "${script_name}: configuring with BLIS_ARCH_TYPE env var renamed to '${rename_blis_arch_type}'." + fi echo "${script_name}: configuring complex return type as \"${complex_return}\"." @@ -3292,6 +3354,15 @@ main() kernel_list_defines="${kernel_list_defines}#define ${kernel_define}\n" done + # Create a list of #includes, one for each addon in addon_list. + addon_list_includes="" + for addon in ${addon_list}; do + + # Create a #define and add it to the running list. + addon_header="\"${addon}.h\"" + addon_list_includes="${addon_list_includes}#include ${addon_header}\n" + done + # -- Determine whether we are performing an out-of-tree build -------------- @@ -3319,7 +3390,7 @@ main() fi - # -- Instantiate config.mk, bli_config.h files from templates -------------- + # -- Instantiate config.mk file from template ------------------------------ # Begin substituting information into the config_mk_in file, outputting # to config_mk_out. @@ -3365,6 +3436,7 @@ main() | sed -e "s/@enable_cblas@/${enable_cblas}/g" \ | sed -e "s/@enable_memkind@/${enable_memkind}/g" \ | sed -e "s/@pragma_omp_simd@/${pragma_omp_simd}/g" \ + | sed -e "s/@addon_list@/${addon_list}/g" \ | sed -e "s/@sandbox@/${sandbox}/g" \ | sed -e "s/@enable_trsm_preinversion@/${enable_trsm_preinversion}/g" \ | sed -e "s/@enable_aocl_dynamic@/${enable_aocl_dynamic}/g" \ @@ -3373,6 +3445,7 @@ main() | sed -e "s/\@enable_aocl_zen\@/${enable_aocl_zen}/g" \ > "${config_mk_out_path}" + # -- Instantiate bli_config.h file from template --------------------------- # Begin substituting information into the bli_config_h_in file, outputting # to bli_config_h_out. NOTE: We use perl instead of sed because the version @@ -3407,8 +3480,21 @@ main() | sed -e "s/@enable_sandbox@/${enable_sandbox_01}/g" \ | sed -e "s/@enable_shared@/${enable_shared_01}/g" \ | sed -e "s/@complex_return_intel@/${complex_return_intel01}/g" \ + | sed -e "s/@disable_blis_arch_type@/${disable_blis_arch_type_01}/g" \ + | sed -e "s/@rename_blis_arch_type@/${rename_blis_arch_type}/g" \ > "${bli_config_h_out_path}" + # -- Instantiate bli_addon.h file from template ---------------------------- + + # Begin substituting information into the bli_addon_h_in file, outputting + # to bli_addon_h_out. NOTE: We use perl instead of sed because the version + # of sed used on OS X is old and does not handle the '\n' character + # intuitively, which was used when constructing ${addon_list_includes}. 
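For concreteness: the pipeline in the next hunk substitutes `@enable_addons@` with the 0/1 flag computed above and `@addon_list_includes@` with the `#include` lines built above. A build configured with `-a foobar -a thing1` (the example addon names used in docs/Addons.md, added later in this patch) would therefore end up with a generated `bli_addon.h` roughly like the sketch below. The header guard is an illustrative assumption, and the exact wrapper around the 0/1 value depends on `build/bli_addon.h.in`, which is not shown in this patch.

```c
/* Sketch of a generated bli_addon.h for './configure -a foobar -a thing1 ...'. */
#ifndef BLIS_ADDON_H
#define BLIS_ADDON_H

/* @enable_addons@ was replaced by 1 (at least one addon requested). */

/* @addon_list_includes@ was replaced by one #include per addon: */
#include "foobar.h"
#include "thing1.h"

#endif
```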
+ echo "${script_name}: creating ${bli_addon_h_out_path} from ${bli_addon_h_in_path}" + cat "${bli_addon_h_in_path}" \ + | perl -pe "s/\@addon_list_includes\@/${addon_list_includes}/g" \ + | sed -e "s/@enable_addons@/${enable_addons_01}/g" \ + > "${bli_addon_h_out_path}" # -- Create top-level object directories ----------------------------------- @@ -3421,7 +3507,6 @@ main() obj_config_dirpath="${base_obj_dirpath}/${config_dir}" - #echo "${script_name}: creating ${obj_config_dirpath}" mkdir -p ${obj_config_dirpath} for conf in ${config_list}; do echo "${script_name}: creating ${obj_config_dirpath}/${conf}" @@ -3431,7 +3516,6 @@ main() obj_kernels_dirpath="${base_obj_dirpath}/${kernels_dir}" - #echo "${script_name}: creating ${obj_kernels_dirpath}" mkdir -p ${obj_kernels_dirpath} for kern in ${kernel_list}; do echo "${script_name}: creating ${obj_kernels_dirpath}/${kern}" @@ -3441,7 +3525,6 @@ main() obj_refkern_dirpath="${base_obj_dirpath}/${refkern_dir}" - #echo "${script_name}: creating ${obj_refkern_dirpath}" mkdir -p ${obj_refkern_dirpath} for conf in ${config_list}; do echo "${script_name}: creating ${obj_refkern_dirpath}/${conf}" @@ -3460,6 +3543,18 @@ main() echo "${script_name}: creating ${obj_frame_dirpath}" mkdir -p ${obj_frame_dirpath} + + if [ -n "${addon_flag}" ]; then + + obj_addon_dirpath="${base_obj_dirpath}/${addon_dir}" + + for addon in ${addon_list}; do + echo "${script_name}: creating ${obj_addon_dirpath}/${addon}" + mkdir -p ${obj_addon_dirpath}/${addon} + done + fi + + if [ -n "${sandbox_flag}" ]; then obj_sandbox_dirpath="${base_obj_dirpath}/${sandbox_dir}" @@ -3487,6 +3582,7 @@ main() echo "${script_name}: creating ${base_lib_dirpath}" mkdir -p ${base_lib_dirpath} + # Create include directory (if it does not already exist). base_include_dirpath="${include_dirpath}/${config_name}" @@ -3545,6 +3641,16 @@ main() echo "${script_name}: mirroring ${aocldtl_dirpath} to ${obj_aocldtl_dirpath}" ${mirror_tree_sh} ${aocldtl_dirpath} ${obj_aocldtl_dirpath} + # Mirror the chosen addon source tree to its object sub-directory. + if [ -n "${addon_flag}" ]; then + + for addon in ${addon_list}; do + + echo "${script_name}: mirroring ${addon_dirpath}/${addon} to ${obj_addon_dirpath}/${addon}" + ${mirror_tree_sh} "${addon_dirpath}/${addon}" "${obj_addon_dirpath}/${addon}" + done + fi + # Mirror the chosen sandbox source tree to its object sub-directory. if [ -n "${sandbox_flag}" ]; then @@ -3643,6 +3749,25 @@ main() ${gen_make_frags_dirpath}/suffix_list \ ${gen_make_frags_dirpath}/ignore_list + # Generate makefile fragments in the addon sub-directory. + if [ -n "${addon_flag}" ]; then + + for addon in ${addon_list}; do + + echo "${script_name}: creating makefile fragments in ${obj_addon_dirpath}/${addon}" + ${gen_make_frags_sh} \ + -h -r -v0 \ + -o ${script_name} \ + -p 'ADDON' \ + ${addon_dirpath}/${addon} \ + ${obj_addon_dirpath}/${addon} \ + ${gen_make_frags_dirpath}/fragment.mk \ + ${gen_make_frags_dirpath}/suffix_list \ + ${gen_make_frags_dirpath}/ignore_list + done + fi + + # Generate makefile fragments in the sandbox sub-directory. 
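The makefile fragments generated above compile whatever sources live under `addon/<name>/`. As a minimal sketch of what such an addon might provide (the `foobar` name and the `foobar_calc` prototype are the hypothetical examples used in docs/Addons.md, added below), an addon header could look like this; `BLIS_EXPORT_ADDON` marks the prototype for export from shared-library builds.

```c
/* addon/foobar/foobar.h -- hypothetical minimal addon header.
   The build system inlines this header into the monolithic blis.h and
   compiles addon/foobar/*.c with the framework CFLAGS.                */
#ifndef FOOBAR_H
#define FOOBAR_H

BLIS_EXPORT_ADDON void foobar_calc( void* a, void* b );

#endif
```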
if [ -n "${sandbox_flag}" ]; then diff --git a/docs/Addons.md b/docs/Addons.md new file mode 100644 index 0000000000..595cebfa4b --- /dev/null +++ b/docs/Addons.md @@ -0,0 +1,231 @@ +## Contents + +* **[Introduction](Addons.md#introduction)** +* **[Enabling addons](Addons.md#enabling-addons)** +* **[Addon rules](Addons.md#addon-rules)** +* **[Caveats](Addons.md#caveats)** +* **[Known issues](Addons.md#known-issues)** +* **[Conclusion](Addons.md#conclusion)** + + +## Introduction + +This file briefly describes the requirements for building a custom BLIS +*addon*. + +Simply put, an addon in BLIS provides additional APIs, operations, and/or +implementations that may be useful to certain users. An addon can be +thought of as a standalone extension of BLIS that does not depend on any +other addon, although addons may utilize existing functionality or kernels +within the core framework. + +By definition, an addon should *never* provide APIs that conflict with +the interfaces that belong to either the [typed API](BLISTypedAPI.md) or the +[object API](BLISObjectAPI.md). Thus, you'll never have to worry about a +properly constructed (and properly functioning) addon interfering with or +otherwise changing core BLIS functionality. + +How does an addon differ from a [sandbox](Sandboxes.md)? Great question! +Sometimes you want to include additional BLIS-like functionality that does +not relate directly to `gemm` or any other BLIS operation. +(By contrast, a sandbox requires you to implement `gemm` whether you want +to or not.) +Furthermore, you may wish to enable multiple addons simultaneously. +(By contrast, only one sandbox may be enabled at a time.) +Thus, the addon feature provides additional flexibility to some +users in a way that sandboxes cannot, while still providing many of the +conveniences of sandboxes. + +## Enabling an addon + +To enable an existing addon at configure-time, you simply specify it as an +option to `configure`. Either of the following usages are accepted: +``` +$ ./configure --enable-addon=foobar auto +$ ./configure -a foobar auto +``` +Here, we tell `configure` that we want to use the `foobar` addon, which +corresponds to a subdirectory of the `addon` directory named `foobar`. +(Reminder: the `auto` argument is the configuration target and +unrelated to addons.) + +You may also enable multiple addons within the same build of BLIS: +``` +$ ./configure -a foobar -a thing1 -a thing2 auto +``` +Note that the default behavior of `configure` is that no addons are enabled. + +As `configure` runs, you should get output that includes lines +similar to: +``` +configure: configuring with addons: +configure: addon/foobar +configure: addon/thing1 +configure: addon/thing2 +``` +And when you build BLIS, the addon source code will be among the last files to +be compiled: +``` +Compiling obj/haswell/addon/foobar/foobar.o ('haswell' CFLAGS for addons) +Compiling obj/haswell/addon/thing1/thing1.o ('haswell' CFLAGS for addons) +Compiling obj/haswell/addon/thing1/thing1_api.o ('haswell' CFLAGS for addons) +Compiling obj/haswell/addon/thing2/thing2_api.o ('haswell' CFLAGS for addons) +... +``` +That's it! After the BLIS library is built, it will contain your chosen +addons. You can always confirm this by using `nm` to confirm the presence +of your API symbols: +``` +$ nm lib/haswell/libblis.a | grep foobar +foobar.o: +0000000000000000 T foobar +``` + +## Addon rules + +Please follow these guidelines for the best developer experience when +creating addons. + +1. 
As with sandboxes, you don't need to worry about creating makefiles. The +BLIS build system will take care of this for you. :) By configuring BLIS with +an addon enabled, `make` will scan your addon subdirectory and compile +all of its source code using similar compilation rules as were used for the rest +of the framework. In addition, the compilation command line will automatically +contain one `-I` option for every subdirectory in your addon, +so it doesn't matter where in your addon directory hierarchy you place your +header files -- they will be found! + +2. We recommend that you write your addon in C99. While you *may* use C++11 +to implement your addon, you should provide a C99 wrapper API to your +implementation so that others can interface with it. There is no guarantee +that the end-user will be using a C++11 compiler, and therefore you should +limit the definitions in your addon header to those that are C99 compliant. +If you write your addon in C++11, you must use one of the BLIS-approved file +extensions for your source files (`.cc`, `.cpp`, `.cxx`) and your local +header files (`.hh`, `.hpp`, `.hxx`). +Note that `blis.h` already contains all of its definitions inside of an +`extern "C"` block, so you should be able to `#include "blis.h"` from your +C++11 source code without any issues. + +3. All of your code related to the addon should reside within the named +addon directory, or some subdirectory therein. If your addon requires +new kernels, you should add kernel source code to an appropriate +microarchitecture-specific subdirectory within the top-level `kernels` +directory so that they are compiled with the correct +microarchitecture-specific optimization flags. + +4. If your addon is named `foobar`, the BLIS build system will expect to +find a header called `foobar.h` somewhere in the `addon/foobar` directory +(or one of its subdirectories). This `foobar.h` header will automatically +be inlined into the monolithic `blis.h` header that is produced by the +BLIS build system. `foobar.h` may `#include` other local headers, each of +which will also (recursively) get inlined into `blis.h`. However, you may +choose to omit some local addon headers from `foobar.h.` You might do this, +for example, because those headers define things that are not needed in +order for the end user to call your addon code. + +5. Your addon APIs will always be available within static library builds of +BLIS, but if you want your addon APIs to be exported as public APIs within +*shared* library builds of BLIS, you'll need to annotate the prototypes +accordingly. (BLIS makes its shared library symbols private by default; this +allows us to export only those functions that we consider to be part of the +public APIs.) This annotation can be done by prefixing function prototypes +with the `BLIS_EXPORT_ADDON` macro as follows: +```c +BLIS_EXPORT_ADDON void foobar_calc( void* a, void* b ); +``` + +6. Do not define any symbols in your addon that conflict with any symbols within +the core framework. For example, don't define a function called `bli_copym()` +in your addon since that function is already defined within BLIS. + +7. Do not define any symbols in your addon that conflict with any symbols within +the C99 standard libraries/headers. For example, don't define a function called +`printf()` since that function is already defined within the C99 standard library. + +8. 
*Try* to not define any symbols in your addon that conflict with symbols in any +other addon, unless your addon is meant to serve as an alternative to the +conflicting addon, in which case conflicting symbol names is okay (since you +will presumably never build with both addons enabled). + +9. When choosing names for your addon files, avoid source filenames that already +exist within BLIS. For example, don't name one of your files `bli_obj.c` +since that file would compile into `bli_obj.o`, which will have already been +placed into the library by the build system. + +10. Similarly, avoid header filenames that already exist within BLIS or C99. +For example, don't name one of your header files `bli_obj.h` since that file +already exists in BLIS. Also, don't name one of your header files `math.h` +since that name would conflict with the `math.h` defined by C99. (This also +means you shouldn't name your addon `math` since normally that name would +require that you provide a `math.h` header inside the addon directory.) + +If you follow these rules, you will be much more likely to have a pleasant +experience integrating your BLIS addon into the larger framework. + +## Caveats + +Notice that the BLIS addons are limited in what they can accomplish. Generally +speaking, addons cannot change existing implementations within BLIS. Instead, +addons aim to provide a way to quickly augment BLIS with additional bundles of +code that extend BLIS's set of functionality in some interesting way. If you +want to define new BLAS-like functions, but don't know where to start, creating +a new addon is an appropriate place to start experimenting. If you want to +change or refactor existing BLIS code, an addon is probably not suited for your +needs. + +Another important limitation is the fact that the build system currently uses +"framework `CFLAGS`" when compiling the addon source files. These are the same +`CFLAGS` used when compiling general framework source code, +``` +# Example framework CFLAGS used by 'haswell' sub-configuration +-O2 -Wall -Wno-unused-function -Wfatal-errors -fPIC -std=c99 +-D_POSIX_C_SOURCE=200112L -Iinclude/haswell -I./frame/3/ +-I./frame/1m/ -I./frame/1f/ -I./frame/1/ -I./frame/include +-DBLIS_VERSION_STRING=\"0.8.1-195\" -fvisibility=hidden +``` +which are likely more general-purpose than the `CFLAGS` used for, say, +optimized kernels or even reference kernels: +``` +# Example optimized kernel CFLAGS used by 'haswell' sub-configuration +-O3 -fomit-frame-pointer -mavx2 -mfma -mfpmath=sse -march=haswell -Wall +-Wno-unused-function -Wfatal-errors -fPIC -std=c99 -D_POSIX_C_SOURCE=200112L +-Iinclude/haswell -I./frame/3/ -I./frame/1m/ -I./frame/1f/ -I./frame/1/ +-I./frame/include -DBLIS_VERSION_STRING=\"0.8.1-195\" -fvisibility=hidden +``` +(To see precisely which flags are being employed for any given file, enable +verbosity at compile-time via `make V=1`.) Compiling addons with these more +versatile `CFLAGS` compiler options means that we only need to compile one +instance of each addon source file, even when targeting multiple +configurations (for example, via `./configure x86_64`). However, it also means +that addons are not ideal for microkernels, as they sometimes need additional +compiler flags in order to +yield the highest performance. If you have a new microkernel you would like to +use within an addon, you can always develop it within that addon. 
However, +once it is stable and ready for use by others, it's best to move the kernel(s) +to the appropriate microarchitecture-specific subdirectory of the `kernels` +directory the kernel(s). This will allow the kernel to be compiled with the +appropriate microarchitecture-specific compiler flags. +Please see the +[Configuration Guide](ConfigurationHowTo) +for more details, and when in doubt, please don't be shy about seeking +guidance from BLIS developers by opening a +[new issue](https://github.com/flame/blis/issues) or sending a message to the +[blis-devel](http://groups.google.com/d/forum/blis-devel) mailing list. + +Notwithstanding these limitations, hopefully you still find BLIS addons +useful! + +## Known issues + +* None yet. + +## Conclusion + +If you encounter any problems, please open +a new [issue on GitHub](https://github.com/flame/blis/issues). + +If you are unsure about how something works, you can still open an issue. Or, you +can send a message to +[blis-devel](https://groups.google.com/d/forum/blis-devel) mailing list. + diff --git a/frame/2/bli_l2_ker_prot.h b/frame/2/bli_l2_ker_prot.h index 82febd761f..5182b5d670 100644 --- a/frame/2/bli_l2_ker_prot.h +++ b/frame/2/bli_l2_ker_prot.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -54,3 +54,16 @@ void PASTEMAC(ch,opname) \ cntx_t* restrict cntx \ ); +#define HER_KER_PROT( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + uplo_t uplo, \ + conj_t conjx, \ + conj_t conjh, \ + dim_t m, \ + ctype* restrict alpha, \ + ctype* restrict x, inc_t incx, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + cntx_t* restrict cntx \ + ); \ No newline at end of file diff --git a/frame/2/gemv/CMakeLists.txt b/frame/2/gemv/CMakeLists.txt index 2f75a00f63..633ec9431a 100644 --- a/frame/2/gemv/CMakeLists.txt +++ b/frame/2/gemv/CMakeLists.txt @@ -12,6 +12,7 @@ target_sources("${PROJECT_NAME}" if(${TARGET_ARCH} STREQUAL zen OR ${TARGET_ARCH} STREQUAL zen2 OR ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL zen4 OR ${TARGET_ARCH} STREQUAL amdzen) target_sources("${PROJECT_NAME}" PRIVATE diff --git a/frame/2/gemv/bli_gemv_unf_var1_amd.c b/frame/2/gemv/bli_gemv_unf_var1_amd.c index 447f8dbc43..a9534bd9a0 100644 --- a/frame/2/gemv/bli_gemv_unf_var1_amd.c +++ b/frame/2/gemv/bli_gemv_unf_var1_amd.c @@ -343,6 +343,14 @@ void bli_sgemv_var1_smart_threading // Calculate the amount data processed per iteration dim_t n_per_loop = n / fuse; double data_per_iter = n_per_loop* m; + + // Exception handling when m-dimenstion or n-dimension is zero + if (bli_zero_dim2(m,n)) + { + *nt = 1; + return; + } + double m_n_ratio = m/n; // When the input value is less than the fuse factor @@ -511,6 +519,7 @@ void bli_sgemv_unf_var1 if ( ( nt_max > 1 ) & ( is_omp_mt_enabled == TRUE ) ) { +#ifdef BLIS_ENABLE_OPENMP b_fuse = 4; //Setting the thread count to the maximum number of threads provided @@ -536,6 +545,7 @@ void bli_sgemv_unf_var1 cntx, nt ); +#endif// BLIS_ENABLE_OPENMP } else { diff --git a/frame/2/hemv/CMakeLists.txt b/frame/2/hemv/CMakeLists.txt index 34820c3762..10e324b52d 100644 --- a/frame/2/hemv/CMakeLists.txt +++ b/frame/2/hemv/CMakeLists.txt @@ -14,7 +14,8 @@ target_sources("${PROJECT_NAME}" # Select AMD specific 
sources for AMD configurations. if(${TARGET_ARCH} STREQUAL zen OR ${TARGET_ARCH} STREQUAL zen2 OR - ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL zen4 OR ${TARGET_ARCH} STREQUAL amdzen) target_sources("${PROJECT_NAME}" PRIVATE diff --git a/frame/2/her/CMakeLists.txt b/frame/2/her/CMakeLists.txt index 37b06d2a7f..b97ee3874b 100644 --- a/frame/2/her/CMakeLists.txt +++ b/frame/2/her/CMakeLists.txt @@ -1,9 +1,25 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_unb_var1.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_unb_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_var_oapi.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR +${TARGET_ARCH} STREQUAL zen2 OR +${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL zen4 OR +${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_unb_var1_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_unb_var2_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_unb_var1.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_unb_var2.c + ) +endif() \ No newline at end of file diff --git a/frame/2/her/bli_her_unb_var1_amd.c b/frame/2/her/bli_her_unb_var1_amd.c new file mode 100644 index 0000000000..1dcb6d0eeb --- /dev/null +++ b/frame/2/her/bli_her_unb_var1_amd.c @@ -0,0 +1,283 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uplo, \ + conj_t conjx, \ + conj_t conjh, \ + dim_t m, \ + ctype* alpha, /* complex alpha allows her variants to also perform syr. 
*/ \ + ctype* x, inc_t incx, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ + ctype* x0; \ + ctype* chi1; \ + ctype* c10t; \ + ctype* gamma11; \ + ctype alpha_local; \ + ctype alpha_chi1; \ + ctype alpha_chi1_chi1; \ + ctype conjx0_chi1; \ + ctype conjx1_chi1; \ + dim_t i; \ + dim_t n_behind; \ + inc_t rs_ct, cs_ct; \ + conj_t conj0, conj1; \ +\ + /* Eliminate unused variable warnings. */ \ + ( void )conj0; \ +\ + /* Make a local copy of alpha and zero out the imaginary component if + we are being invoked as her, since her requires alpha to be real. */ \ + PASTEMAC(ch,copys)( *alpha, alpha_local ); \ + if ( bli_is_conj( conjh ) ) \ + { \ + PASTEMAC(ch,seti0s)( alpha_local ); \ + } \ +\ + /* The algorithm will be expressed in terms of the lower triangular case; + the upper triangular case is supported by swapping the row and column + strides of A and toggling some conj parameters. */ \ + if ( bli_is_lower( uplo ) ) \ + { \ + rs_ct = rs_c; \ + cs_ct = cs_c; \ + } \ + else /* if ( bli_is_upper( uplo ) ) */ \ + { \ + rs_ct = cs_c; \ + cs_ct = rs_c; \ +\ + /* Toggle conjugation of conjx, but only if we are being invoked + as her; for syr, conjx is unchanged. */ \ + conjx = bli_apply_conj( conjh, conjx ); \ + } \ +\ + /* Apply conjh (which carries the conjugation component of the Hermitian + transpose, if applicable) to conjx as needed to arrive at the effective + conjugation for the scalar and vector subproblems. */ \ + conj0 = conjx; \ + conj1 = bli_apply_conj( conjh, conjx ); \ +\ + PASTECH(ch,axpyv_ker_ft) kfp_av; \ +\ + /* Query the context for the kernel function pointer. */ \ + kfp_av = bli_cntx_get_l1v_ker_dt( dt, BLIS_AXPYV_KER, cntx ); \ +\ + for ( i = 0; i < m; ++i ) \ + { \ + n_behind = i; \ + x0 = x + (0 )*incx; \ + chi1 = x + (i )*incx; \ + c10t = c + (i )*rs_ct + (0 )*cs_ct; \ + gamma11 = c + (i )*rs_ct + (i )*cs_ct; \ +\ + /* Apply conjx to chi1. */ \ + PASTEMAC(ch,copycjs)( conj0, *chi1, conjx0_chi1 ); \ + PASTEMAC(ch,copycjs)( conj1, *chi1, conjx1_chi1 ); \ +\ + /* Compute scalar for vector subproblem. */ \ + PASTEMAC(ch,scal2s)( alpha_local, conjx0_chi1, alpha_chi1 ); \ +\ + /* Compute alpha * chi1 * conj(chi1) after chi1 has already been + conjugated, if needed, by conjx. */ \ + PASTEMAC(ch,scal2s)( alpha_chi1, conjx1_chi1, alpha_chi1_chi1 ); \ +\ + /* c10t = c10t + alpha * chi1 * x0'; */ \ + kfp_av \ + ( \ + conj1, \ + n_behind, \ + &alpha_chi1, \ + x0, incx, \ + c10t, cs_ct, \ + cntx \ + ); \ +\ + /* gamma11 = gamma11 + alpha * chi1 * conj(chi1); */ \ + PASTEMAC(ch,adds)( alpha_chi1_chi1, *gamma11 ); \ +\ + /* For her2, explicitly set the imaginary component of gamma11 to + zero. */ \ + if ( bli_is_conj( conjh ) ) \ + PASTEMAC(ch,seti0s)( *gamma11 ); \ + } \ +} + +INSERT_GENTFUNC_BASIC0_SD( her_unb_var1 ) +GENTFUNC( scomplex, c, her_unb_var1 ) + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uplo, \ + conj_t conjx, \ + conj_t conjh, \ + dim_t m, \ + ctype* alpha, /* complex alpha allows her variants to also perform syr. 
*/ \ + ctype* x, inc_t incx, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ + /* Redirect to intrinsic implementation of HER for dcomplex */ \ + if ( bli_cpuid_is_avx_supported() == TRUE && \ + ( rs_c == 1 || cs_c == 1 ) && \ + ( bli_is_upper( uplo ) || bli_is_lower( uplo ) ) && \ + bli_is_conj(conjh) && incx == 1 ) \ + { \ + bli_zher_zen_int_var1 \ + ( \ + uplo, \ + conjx, \ + conjh, \ + m, \ + alpha, \ + x, \ + incx, \ + c, \ + rs_c, \ + cs_c, \ + cntx \ + ); \ + } \ + else \ + { \ + ctype* x0; \ + ctype* chi1; \ + ctype* c10t; \ + ctype* gamma11; \ + ctype alpha_local; \ + ctype alpha_chi1; \ + ctype alpha_chi1_chi1; \ + ctype conjx0_chi1; \ + ctype conjx1_chi1; \ + dim_t i; \ + dim_t n_behind; \ + inc_t rs_ct, cs_ct; \ + conj_t conj0, conj1; \ +\ + /* Eliminate unused variable warnings. */ \ + ( void )conj0; \ +\ + /* Make a local copy of alpha and zero out the imaginary component if + we are being invoked as her, since her requires alpha to be real. */ \ + PASTEMAC(ch,copys)( *alpha, alpha_local ); \ + if ( bli_is_conj( conjh ) ) \ + { \ + PASTEMAC(ch,seti0s)( alpha_local ); \ + } \ +\ + /* The algorithm will be expressed in terms of the lower triangular case; + the upper triangular case is supported by swapping the row and column + strides of A and toggling some conj parameters. */ \ + if ( bli_is_lower( uplo ) ) \ + { \ + rs_ct = rs_c; \ + cs_ct = cs_c; \ + } \ + else /* if ( bli_is_upper( uplo ) ) */ \ + { \ + rs_ct = cs_c; \ + cs_ct = rs_c; \ +\ + /* Toggle conjugation of conjx, but only if we are being invoked + as her; for syr, conjx is unchanged. */ \ + conjx = bli_apply_conj( conjh, conjx ); \ + } \ +\ + /* Apply conjh (which carries the conjugation component of the Hermitian + transpose, if applicable) to conjx as needed to arrive at the effective + conjugation for the scalar and vector subproblems. */ \ + conj0 = conjx; \ + conj1 = bli_apply_conj( conjh, conjx ); \ +\ + PASTECH(ch,axpyv_ker_ft) kfp_av; \ +\ + /* Query the context for the kernel function pointer. */ \ + kfp_av = bli_cntx_get_l1v_ker_dt( dt, BLIS_AXPYV_KER, cntx ); \ +\ + for ( i = 0; i < m; ++i ) \ + { \ + n_behind = i; \ + x0 = x + (0 )*incx; \ + chi1 = x + (i )*incx; \ + c10t = c + (i )*rs_ct + (0 )*cs_ct; \ + gamma11 = c + (i )*rs_ct + (i )*cs_ct; \ +\ + /* Apply conjx to chi1. */ \ + PASTEMAC(ch,copycjs)( conj0, *chi1, conjx0_chi1 ); \ + PASTEMAC(ch,copycjs)( conj1, *chi1, conjx1_chi1 ); \ +\ + /* Compute scalar for vector subproblem. */ \ + PASTEMAC(ch,scal2s)( alpha_local, conjx0_chi1, alpha_chi1 ); \ +\ + /* Compute alpha * chi1 * conj(chi1) after chi1 has already been + conjugated, if needed, by conjx. */ \ + PASTEMAC(ch,scal2s)( alpha_chi1, conjx1_chi1, alpha_chi1_chi1 ); \ +\ + /* c10t = c10t + alpha * chi1 * x0'; */ \ + kfp_av \ + ( \ + conj1, \ + n_behind, \ + &alpha_chi1, \ + x0, incx, \ + c10t, cs_ct, \ + cntx \ + ); \ +\ + /* gamma11 = gamma11 + alpha * chi1 * conj(chi1); */ \ + PASTEMAC(ch,adds)( alpha_chi1_chi1, *gamma11 ); \ +\ + /* For her2, explicitly set the imaginary component of gamma11 to + zero. 
*/ \ + if ( bli_is_conj( conjh ) ) \ + PASTEMAC(ch,seti0s)( *gamma11 ); \ + } \ + } \ +} +GENTFUNC( dcomplex, z, her_unb_var1 ) \ No newline at end of file diff --git a/frame/2/her/bli_her_unb_var2_amd.c b/frame/2/her/bli_her_unb_var2_amd.c new file mode 100644 index 0000000000..f16ef42a76 --- /dev/null +++ b/frame/2/her/bli_her_unb_var2_amd.c @@ -0,0 +1,283 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uplo, \ + conj_t conjx, \ + conj_t conjh, \ + dim_t m, \ + ctype* alpha, /* complex alpha allows her variants to also perform syr. */ \ + ctype* x, inc_t incx, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ + ctype* chi1; \ + ctype* x2; \ + ctype* gamma11; \ + ctype* c21; \ + ctype alpha_local; \ + ctype alpha_chi1; \ + ctype alpha_chi1_chi1; \ + ctype conjx0_chi1; \ + ctype conjx1_chi1; \ + dim_t i; \ + dim_t n_ahead; \ + inc_t rs_ct, cs_ct; \ + conj_t conj0, conj1; \ +\ + /* Eliminate unused variable warnings. */ \ + ( void )conj0; \ +\ + /* Make a local copy of alpha and zero out the imaginary component if + we are being invoked as her, since her requires alpha to be real. */ \ + PASTEMAC(ch,copys)( *alpha, alpha_local ); \ + if ( bli_is_conj( conjh ) ) \ + { \ + PASTEMAC(ch,seti0s)( alpha_local ); \ + } \ +\ + /* The algorithm will be expressed in terms of the lower triangular case; + the upper triangular case is supported by swapping the row and column + strides of A and toggling some conj parameters. */ \ + if ( bli_is_lower( uplo ) ) \ + { \ + rs_ct = rs_c; \ + cs_ct = cs_c; \ + } \ + else /* if ( bli_is_upper( uplo ) ) */ \ + { \ + rs_ct = cs_c; \ + cs_ct = rs_c; \ +\ + /* Toggle conjugation of conjx, but only if we are being invoked + as her; for syr, conjx is unchanged. 
*/ \ + conjx = bli_apply_conj( conjh, conjx ); \ + } \ +\ + /* Apply conjh (which carries the conjugation component of the Hermitian + transpose, if applicable) to conjx as needed to arrive at the effective + conjugation for the scalar and vector subproblems. */ \ + conj0 = bli_apply_conj( conjh, conjx ); \ + conj1 = conjx; \ +\ + PASTECH(ch,axpyv_ker_ft) kfp_av; \ +\ + /* Query the context for the kernel function pointer. */ \ + kfp_av = bli_cntx_get_l1v_ker_dt( dt, BLIS_AXPYV_KER, cntx ); \ +\ + for ( i = 0; i < m; ++i ) \ + { \ + n_ahead = m - i - 1; \ + chi1 = x + (i )*incx; \ + x2 = x + (i+1)*incx; \ + gamma11 = c + (i )*rs_ct + (i )*cs_ct; \ + c21 = c + (i+1)*rs_ct + (i )*cs_ct; \ +\ + /* Apply conjx to chi1. */ \ + PASTEMAC(ch,copycjs)( conj0, *chi1, conjx0_chi1 ); \ + PASTEMAC(ch,copycjs)( conj1, *chi1, conjx1_chi1 ); \ +\ + /* Compute scalar for vector subproblem. */ \ + PASTEMAC(ch,scal2s)( alpha_local, conjx0_chi1, alpha_chi1 ); \ +\ + /* Compute alpha * chi1 * conj(chi1) after chi1 has already been + conjugated, if needed, by conjx. */ \ + PASTEMAC(ch,scal2s)( alpha_chi1, conjx1_chi1, alpha_chi1_chi1 ); \ +\ + /* c21 = c21 + alpha * x2 * conj(chi1); */ \ + kfp_av \ + ( \ + conj1, \ + n_ahead, \ + &alpha_chi1, \ + x2, incx, \ + c21, rs_ct, \ + cntx \ + ); \ +\ + /* gamma11 = gamma11 + alpha * chi1 * conj(chi1); */ \ + PASTEMAC(ch,adds)( alpha_chi1_chi1, *gamma11 ); \ +\ + /* For her, explicitly set the imaginary component of gamma11 to + zero. */ \ + if ( bli_is_conj( conjh ) ) \ + PASTEMAC(ch,seti0s)( *gamma11 ); \ + } \ +} + +INSERT_GENTFUNC_BASIC0_SD( her_unb_var2 ) +GENTFUNC( scomplex, c, her_unb_var2 ) + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uplo, \ + conj_t conjx, \ + conj_t conjh, \ + dim_t m, \ + ctype* alpha, /* complex alpha allows her variants to also perform syr. */ \ + ctype* x, inc_t incx, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ + /* Redirect to intrinsic implementation of HER for unit increment */ \ + if ( bli_cpuid_is_avx_supported() == TRUE && \ + ( rs_c == 1 || cs_c == 1 ) && \ + ( bli_is_upper( uplo ) || bli_is_lower( uplo ) ) && \ + bli_is_conj(conjh) && incx == 1 ) \ + { \ + bli_zher_zen_int_var2 \ + ( \ + uplo, \ + conjx, \ + conjh, \ + m, \ + alpha, \ + x, \ + incx, \ + c, \ + rs_c, \ + cs_c, \ + cntx \ + ); \ + } \ + else \ + { \ + ctype* chi1; \ + ctype* x2; \ + ctype* gamma11; \ + ctype* c21; \ + ctype alpha_local; \ + ctype alpha_chi1; \ + ctype alpha_chi1_chi1; \ + ctype conjx0_chi1; \ + ctype conjx1_chi1; \ + dim_t i; \ + dim_t n_ahead; \ + inc_t rs_ct, cs_ct; \ + conj_t conj0, conj1; \ +\ + /* Eliminate unused variable warnings. */ \ + ( void )conj0; \ +\ + /* Make a local copy of alpha and zero out the imaginary component if + we are being invoked as her, since her requires alpha to be real. */ \ + PASTEMAC(ch,copys)( *alpha, alpha_local ); \ + if ( bli_is_conj( conjh ) ) \ + { \ + PASTEMAC(ch,seti0s)( alpha_local ); \ + } \ +\ + /* The algorithm will be expressed in terms of the lower triangular case; + the upper triangular case is supported by swapping the row and column + strides of A and toggling some conj parameters. */ \ + if ( bli_is_lower( uplo ) ) \ + { \ + rs_ct = rs_c; \ + cs_ct = cs_c; \ + } \ + else /* if ( bli_is_upper( uplo ) ) */ \ + { \ + rs_ct = cs_c; \ + cs_ct = rs_c; \ +\ + /* Toggle conjugation of conjx, but only if we are being invoked + as her; for syr, conjx is unchanged. 
*/ \ + conjx = bli_apply_conj( conjh, conjx ); \ + } \ +\ + /* Apply conjh (which carries the conjugation component of the Hermitian + transpose, if applicable) to conjx as needed to arrive at the effective + conjugation for the scalar and vector subproblems. */ \ + conj0 = bli_apply_conj( conjh, conjx ); \ + conj1 = conjx; \ +\ + PASTECH(ch,axpyv_ker_ft) kfp_av; \ +\ + /* Query the context for the kernel function pointer. */ \ + kfp_av = bli_cntx_get_l1v_ker_dt( dt, BLIS_AXPYV_KER, cntx ); \ +\ + for ( i = 0; i < m; ++i ) \ + { \ + n_ahead = m - i - 1; \ + chi1 = x + (i )*incx; \ + x2 = x + (i+1)*incx; \ + gamma11 = c + (i )*rs_ct + (i )*cs_ct; \ + c21 = c + (i+1)*rs_ct + (i )*cs_ct; \ +\ + /* Apply conjx to chi1. */ \ + PASTEMAC(ch,copycjs)( conj0, *chi1, conjx0_chi1 ); \ + PASTEMAC(ch,copycjs)( conj1, *chi1, conjx1_chi1 ); \ +\ + /* Compute scalar for vector subproblem. */ \ + PASTEMAC(ch,scal2s)( alpha_local, conjx0_chi1, alpha_chi1 ); \ +\ + /* Compute alpha * chi1 * conj(chi1) after chi1 has already been + conjugated, if needed, by conjx. */ \ + PASTEMAC(ch,scal2s)( alpha_chi1, conjx1_chi1, alpha_chi1_chi1 ); \ +\ + /* c21 = c21 + alpha * x2 * conj(chi1); */ \ + kfp_av \ + ( \ + conj1, \ + n_ahead, \ + &alpha_chi1, \ + x2, incx, \ + c21, rs_ct, \ + cntx \ + ); \ +\ + /* gamma11 = gamma11 + alpha * chi1 * conj(chi1); */ \ + PASTEMAC(ch,adds)( alpha_chi1_chi1, *gamma11 ); \ +\ + /* For her, explicitly set the imaginary component of gamma11 to + zero. */ \ + if ( bli_is_conj( conjh ) ) \ + PASTEMAC(ch,seti0s)( *gamma11 ); \ + } \ + } \ +} +GENTFUNC( dcomplex, z, her_unb_var2 ) \ No newline at end of file diff --git a/frame/2/her2/CMakeLists.txt b/frame/2/her2/CMakeLists.txt index 83629df8f5..cfdeb2480d 100644 --- a/frame/2/her2/CMakeLists.txt +++ b/frame/2/her2/CMakeLists.txt @@ -12,7 +12,8 @@ target_sources("${PROJECT_NAME}" # Select AMD specific sources for AMD configurations. if(${TARGET_ARCH} STREQUAL zen OR ${TARGET_ARCH} STREQUAL zen2 OR - ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL zen4 OR ${TARGET_ARCH} STREQUAL amdzen) target_sources("${PROJECT_NAME}" PRIVATE diff --git a/frame/2/trsv/CMakeLists.txt b/frame/2/trsv/CMakeLists.txt index b07389340e..f1aacc745c 100644 --- a/frame/2/trsv/CMakeLists.txt +++ b/frame/2/trsv/CMakeLists.txt @@ -10,7 +10,8 @@ target_sources("${PROJECT_NAME}" # Select AMD specific sources for AMD configurations. if(${TARGET_ARCH} STREQUAL zen OR ${TARGET_ARCH} STREQUAL zen2 OR - ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL zen4 OR ${TARGET_ARCH} STREQUAL amdzen) target_sources("${PROJECT_NAME}" PRIVATE diff --git a/frame/3/CMakeLists.txt b/frame/3/CMakeLists.txt index e9d7da7b8e..734622344a 100644 --- a/frame/3/CMakeLists.txt +++ b/frame/3/CMakeLists.txt @@ -30,7 +30,8 @@ target_sources("${PROJECT_NAME}" # Select AMD specific sources for AMD configurations. if(${TARGET_ARCH} STREQUAL zen OR ${TARGET_ARCH} STREQUAL zen2 OR - ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL zen4 OR ${TARGET_ARCH} STREQUAL amdzen) target_sources("${PROJECT_NAME}" PRIVATE diff --git a/frame/3/bli_l3_packm.c b/frame/3/bli_l3_packm.c index d6efb593cc..1134bdc1fd 100644 --- a/frame/3/bli_l3_packm.c +++ b/frame/3/bli_l3_packm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. 
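The two new `her` variant files above both compute the rank-1 Hermitian update C := C + alpha * x * conj(x)^T on the stored triangle, with alpha treated as real and the diagonal's imaginary parts forced to zero. The standalone program below is a small sketch of that update for the lower case in plain C99 complex arithmetic (no BLIS types), mirroring the one-column-per-iteration structure of the variant loops.

```c
#include <stdio.h>
#include <complex.h>

int main( void )
{
    const double alpha = 2.0;                       /* her: alpha is real    */
    double complex x[2] = { 1.0 + 2.0*I, 3.0 - 1.0*I };
    double complex C[2][2] = { { 5.0, 0.0 },        /* lower triangle stored */
                               { 0.0, 7.0 } };

    for ( int i = 0; i < 2; ++i )
    {
        /* c10t += alpha * chi1 * conj(x0): update row i left of the diagonal. */
        for ( int j = 0; j < i; ++j )
            C[i][j] += alpha * x[i] * conj( x[j] );

        /* gamma11 += alpha * |chi1|^2, with the imaginary part set to zero.   */
        C[i][i] = creal( C[i][i] + alpha * x[i] * conj( x[i] ) );
    }

    printf( "C[0][0] = %g\n", creal( C[0][0] ) );                 /*  15      */
    printf( "C[1][0] = %g %+gi\n", creal( C[1][0] ),              /*  2 - 14i */
                                   cimag( C[1][0] ) );
    printf( "C[1][1] = %g\n", creal( C[1][1] ) );                 /*  27      */
    return 0;
}
```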
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -47,6 +47,8 @@ void bli_l3_packm { packbuf_t pack_buf_type; mem_t* cntl_mem_p; + mem_t* local_mem_p; + mem_t local_mem_s; siz_t size_needed; // FGVZ: Not sure why we need this barrier, but we do. @@ -80,9 +82,6 @@ void bli_l3_packm // all threads in the chief's thread group. if ( bli_mem_is_unalloc( cntl_mem_p ) ) { - mem_t* local_mem_p; - mem_t local_mem_s; - if ( bli_thread_am_ochief( thread ) ) { #ifdef BLIS_ENABLE_MEM_TRACING @@ -110,9 +109,6 @@ void bli_l3_packm } else // ( bli_mem_is_alloc( cntl_mem_p ) ) { - mem_t* local_mem_p; - mem_t local_mem_s; - // If the mem_t entry in the control tree does NOT contain a NULL // buffer, then a block has already been acquired from the memory // broker and cached in the control tree. diff --git a/frame/3/bli_l3_packm.h b/frame/3/bli_l3_packm.h index 696dabf593..50a6e2f9c1 100644 --- a/frame/3/bli_l3_packm.h +++ b/frame/3/bli_l3_packm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/3/bli_l3_sup.c b/frame/3/bli_l3_sup.c index d23df8c1e5..867ccd200c 100644 --- a/frame/3/bli_l3_sup.c +++ b/frame/3/bli_l3_sup.c @@ -152,8 +152,7 @@ err_t bli_gemmsup // Query the small/unpacked handler from the context and invoke it. gemmsup_oft gemmsup_fp = bli_cntx_get_l3_sup_handler( BLIS_GEMM, cntx ); - return - gemmsup_fp + err_t ret_gemmsup_fp = gemmsup_fp ( alpha, a, @@ -165,6 +164,7 @@ err_t bli_gemmsup ); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2); + return ret_gemmsup_fp; } err_t bli_gemmtsup @@ -285,8 +285,7 @@ printf( "dims: %d %d %d (threshs: %d %d %d)\n", // Query the small/unpacked handler from the context and invoke it. gemmtsup_oft gemmtsup_fp = bli_cntx_get_l3_sup_handler( BLIS_GEMMT, cntx ); - return - gemmtsup_fp + err_t ret_gemmtsup_fp = gemmtsup_fp ( alpha, a, @@ -298,6 +297,7 @@ printf( "dims: %d %d %d (threshs: %d %d %d)\n", ); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2); + return ret_gemmtsup_fp; } err_t bli_syrksup @@ -414,8 +414,7 @@ printf( "dims: %d %d %d (threshs: %d %d %d)\n", // Query the small/unpacked handler from the context and invoke it. gemmtsup_oft gemmtsup_fp = bli_cntx_get_l3_sup_handler( BLIS_GEMMT, cntx ); - return - gemmtsup_fp + err_t ret_gemmtsup_fp = gemmtsup_fp ( alpha, a, @@ -427,4 +426,5 @@ printf( "dims: %d %d %d (threshs: %d %d %d)\n", ); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2); + return ret_gemmtsup_fp; } diff --git a/frame/3/bli_l3_sup_int.c b/frame/3/bli_l3_sup_int.c index 909f480599..e215a7d825 100644 --- a/frame/3/bli_l3_sup_int.c +++ b/frame/3/bli_l3_sup_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2019-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019-22, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/3/bli_l3_sup_int_amd.c b/frame/3/bli_l3_sup_int_amd.c index e00cc54ad0..b226b135d0 100644 --- a/frame/3/bli_l3_sup_int_amd.c +++ b/frame/3/bli_l3_sup_int_amd.c @@ -52,21 +52,15 @@ err_t bli_gemmsup_int const dim_t m = bli_obj_length( c ); const dim_t n = bli_obj_width( c ); const dim_t k = bli_obj_width( a ); - const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); - const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); + const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); + const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); const bool auto_factor = bli_rntm_auto_factor( rntm ); const dim_t n_threads = bli_rntm_num_threads( rntm ); - + bool use_pb = FALSE; dim_t jc_new; dim_t ic_new; - - //bli_gemmsup_ref_var2 - //bli_gemmsup_ref_var1 - #if 0 - bli_gemmsup_ref_var1n - #else - #endif const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); const bool is_rrr_rrc_rcr_crr = ( stor_id == BLIS_RRR || stor_id == BLIS_RRC || @@ -96,6 +90,9 @@ err_t bli_gemmsup_int const dim_t mu = m / MR; const dim_t nu = n / NR; + // Heuristic to decide whether to use 1n variant or not for sgemm. + use_pb = ( ( nu >= ( 4 * mu ) ) && ( k >= KC ) ) ? TRUE : FALSE; + // If the parallel thread factorization was automatic, we update it // with a new factorization based on the matrix dimensions in units // of micropanels. However in case smart threading is enabled, @@ -113,15 +110,54 @@ err_t bli_gemmsup_int bli_l3_sup_thrinfo_update_root( rntm, thread ); } - /*Enable packing for B matrix for higher sizes*/ + //Enable packing for B matrix for higher sizes if(bli_is_float(dt) && (n_threads==1)) { if((m > 240) && (k > 240) && (n > 240)) - bli_rntm_set_pack_b( 1, rntm ); + bli_rntm_set_pack_b( 1, rntm );//packb } - bli_gemmsup_ref_var2m( BLIS_NO_TRANSPOSE, - alpha, a, b, beta, c, - stor_id, cntx, rntm, thread ); + //Enable packing of B matrix for complex data type + if (bli_is_dcomplex(dt) && (n_threads == 1)) + { + if ((m > 55) && (k > 55) && (n > 55)) + bli_rntm_set_pack_b(1, rntm);//packb + } + + //Enable packing of B matrix for double data type when dims at per + //thread level are above caches and enable packing of A when transA + //(RRC or CRC storage ids) to avoid rd kernels + if(bli_is_double(dt)) + { + dim_t m_pt = (m/bli_rntm_ways_for( BLIS_MC, rntm )); + dim_t n_pt = (n/bli_rntm_ways_for( BLIS_NC, rntm )); + + if(k > 120) + { + if(((m_pt > 320) && (n_pt > 120)) || ((m_pt > 120) && (n_pt > 320))) + { + bli_rntm_set_pack_b(1, rntm);//packb + + if(stor_id==BLIS_RRC || stor_id==BLIS_CRC) + bli_rntm_set_pack_a(1, rntm);//packa + } + } + } + + // Using the 1n kernel (B broadcast) gave better performance for sgemm + // in single-thread scenario, given the number of n panels are + // sufficiently larger than m panels. 
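To make the `use_pb` heuristic above concrete, the standalone program below plugs in example dimensions. The blocking values `MR`, `NR`, and `KC` are stand-ins chosen only for illustration; in BLIS they are queried from the context via `bli_cntx_get_l3_sup_blksz_def_dt()`, as shown above.

```c
#include <stdio.h>

int main( void )
{
    /* Assumed sup blocking for sgemm, for illustration only. */
    const long MR = 6, NR = 16, KC = 512;

    /* Example problem: short-and-wide C with a deep k dimension. */
    const long m = 48, n = 4096, k = 600;

    const long mu = m / MR;   /* micropanels of C along m ->   8 */
    const long nu = n / NR;   /* micropanels of C along n -> 256 */

    const int use_pb = ( nu >= 4 * mu ) && ( k >= KC );

    printf( "use_pb = %d -> %s\n", use_pb,
            use_pb ? "bli_gemmsup_ref_var1n" : "bli_gemmsup_ref_var2m" );
    return 0;
}
```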
+ if ( bli_is_float( dt ) && ( n_threads == 1 ) && ( use_pb == TRUE ) ) + { + bli_gemmsup_ref_var1n( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } + else + { + bli_gemmsup_ref_var2m( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } } else { @@ -132,6 +168,8 @@ err_t bli_gemmsup_int const dim_t mu = n / MR; // the n becomes m after a transposition const dim_t nu = m / NR; // the m becomes n after a transposition + use_pb = ( ( nu >= ( 4 * mu ) ) && ( k >= KC ) ) ? TRUE : FALSE; + if ( auto_factor ) { // In the block-panel algorithm, the m dimension is parallelized @@ -149,12 +187,48 @@ err_t bli_gemmsup_int * becomes pack B inside var2m because this is transpose case*/ if(bli_is_float(dt) && (n_threads==1)) { if((m > 240) && (k > 240) && (n > 240)) - bli_rntm_set_pack_a( 1, rntm ); + bli_rntm_set_pack_a( 1, rntm );//packb + } + + /*Enable packing of A matrix for complex data type*/ + if (bli_is_dcomplex(dt) && (n_threads == 1)) + { + if ((m > 55) && (k > 55) && (n > 55)) + bli_rntm_set_pack_a(1, rntm);//packb } - bli_gemmsup_ref_var2m( BLIS_TRANSPOSE, - alpha, a, b, beta, c, - stor_id, cntx, rntm, thread ); + //Enable packing of B matrix for double data type when dims at per + //thread level are above caches and enable packing of A when transA + //(RRC or CRC storage ids) to avoid rd kernels + if(bli_is_double(dt)) + { + dim_t m_pt = (m/bli_rntm_ways_for( BLIS_NC, rntm )); + dim_t n_pt = (n/bli_rntm_ways_for( BLIS_MC, rntm )); + + if(k > 120) + { + if(((m_pt > 320) && (n_pt > 120)) || ((m_pt > 120) && (n_pt > 320))) + { + bli_rntm_set_pack_a(1, rntm);//packb + + if(stor_id==BLIS_RRC || stor_id==BLIS_CRC) + bli_rntm_set_pack_b(1, rntm);//packa + } + } + } + + if ( bli_is_float( dt ) && ( n_threads == 1 ) && ( use_pb == TRUE ) ) + { + bli_gemmsup_ref_var1n( BLIS_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } + else + { + bli_gemmsup_ref_var2m( BLIS_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4); diff --git a/frame/3/bli_l3_sup_packm_a.c b/frame/3/bli_l3_sup_packm_a.c index 6933b6906f..196dfae0b5 100644 --- a/frame/3/bli_l3_sup_packm_a.c +++ b/frame/3/bli_l3_sup_packm_a.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -58,7 +58,7 @@ void PASTEMAC(ch,opname) \ } \ else /* if ( will_pack == TRUE ) */ \ { \ - /* NOTE: This is "rounding up" of the last upanel is actually optional + /* NOTE: This "rounding up" of the last upanel is actually optional for the rrc/crc cases, but absolutely necessary for the other cases since we NEED that last micropanel to have the same ldim (cs_p) as the other micropanels. Why? So that millikernels can use the same diff --git a/frame/3/bli_l3_sup_packm_b.c b/frame/3/bli_l3_sup_packm_b.c index 20c41b6b0b..b733122825 100644 --- a/frame/3/bli_l3_sup_packm_b.c +++ b/frame/3/bli_l3_sup_packm_b.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2022, Advanced Micro Devices, Inc. 
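Similarly, the per-thread packing heuristic for double precision above can be illustrated with concrete numbers; in the non-transpose case it enables packing of B. The thread factorization below is hypothetical; in BLIS the MC/NC ways come from the runtime object via `bli_rntm_ways_for()`, as in the code above.

```c
#include <stdio.h>
#include <stdbool.h>

int main( void )
{
    /* Example dgemm problem and an assumed 4x2 thread factorization. */
    const long m = 2000, n = 2000, k = 256;
    const long mc_ways = 4, nc_ways = 2;

    const long m_pt = m / mc_ways;   /*  500 rows    per thread along m */
    const long n_pt = n / nc_ways;   /* 1000 columns per thread along n */

    const bool pack_b = ( k > 120 ) &&
                        ( ( ( m_pt > 320 ) && ( n_pt > 120 ) ) ||
                          ( ( m_pt > 120 ) && ( n_pt > 320 ) ) );

    printf( "m_pt=%ld n_pt=%ld -> pack B: %s\n",
            m_pt, n_pt, pack_b ? "yes" : "no" );
    return 0;
}
```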
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -58,7 +58,7 @@ void PASTEMAC(ch,opname) \ } \ else /* if ( will_pack == TRUE ) */ \ { \ - /* NOTE: This is "rounding up" of the last upanel is actually optional + /* NOTE: This "rounding up" of the last upanel is actually optional for the rrc/crc cases, but absolutely necessary for the other cases since we NEED that last micropanel to have the same ldim (cs_p) as the other micropanels. Why? So that millikernels can use the same @@ -285,15 +285,15 @@ void PASTEMAC(ch,opname) \ } \ else \ { \ - /* All other stor3_t ids: pack A to column-stored row-panels. */ \ + /* All other stor3_t ids: pack B to row-stored column-panels. */ \ *rs_p = nr; \ *cs_p = 1; \ \ *pd_p = nr; \ *ps_p = k * nr; \ \ - /* Set the schema to "packed row panels" to indicate packing to - conventional column-stored row panels. */ \ + /* Set the schema to "packed column panels" to indicate packing to + conventional row-stored column panels. */ \ *schema = BLIS_PACKED_COL_PANELS; \ } \ \ diff --git a/frame/3/gemm/CMakeLists.txt b/frame/3/gemm/CMakeLists.txt index 825dd745ca..8969680031 100644 --- a/frame/3/gemm/CMakeLists.txt +++ b/frame/3/gemm/CMakeLists.txt @@ -18,7 +18,8 @@ target_sources("${PROJECT_NAME}" # Select AMD specific sources for AMD configurations. if(${TARGET_ARCH} STREQUAL zen OR ${TARGET_ARCH} STREQUAL zen2 OR -${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL zen4 OR ${TARGET_ARCH} STREQUAL amdzen) target_sources("${PROJECT_NAME}" PRIVATE diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c index a9bada995d..063f40ff9c 100644 --- a/frame/3/gemm/bli_gemm_front.c +++ b/frame/3/gemm/bli_gemm_front.c @@ -294,88 +294,4 @@ void bli_gemm_front // ----------------------------------------------------------------------------- -#if 0 - if ( bli_obj_dt( a ) != bli_obj_dt( b ) || - bli_obj_dt( a ) != bli_obj_dt( c ) || - bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) - { - const bool a_is_real = bli_obj_is_real( a ); - const bool a_is_comp = bli_obj_is_complex( a ); - const bool b_is_real = bli_obj_is_real( b ); - const bool b_is_comp = bli_obj_is_complex( b ); - const bool c_is_real = bli_obj_is_real( c ); - const bool c_is_comp = bli_obj_is_complex( c ); - - const bool a_is_single = bli_obj_is_single_prec( a ); - const bool a_is_double = bli_obj_is_double_prec( a ); - const bool b_is_single = bli_obj_is_single_prec( b ); - const bool b_is_double = bli_obj_is_double_prec( b ); - const bool c_is_single = bli_obj_is_single_prec( c ); - const bool c_is_double = bli_obj_is_double_prec( c ); - - const bool comp_single = bli_obj_comp_prec( c ) == BLIS_SINGLE_PREC; - const bool comp_double = bli_obj_comp_prec( c ) == BLIS_DOUBLE_PREC; - - const bool mixeddomain = bli_obj_domain( c ) != bli_obj_domain( a ) || - bli_obj_domain( c ) != bli_obj_domain( b ); - - ( void )a_is_real; ( void )a_is_comp; - ( void )b_is_real; ( void )b_is_comp; - ( void )c_is_real; ( void )c_is_comp; - ( void )a_is_single; ( void )a_is_double; - ( void )b_is_single; ( void )b_is_double; - ( void )c_is_single; ( void )c_is_double; - ( void )comp_single; ( void )comp_double; - - if ( - //( c_is_comp && a_is_comp && b_is_real ) || - //( c_is_comp && a_is_real && b_is_comp ) || - //( c_is_real && a_is_comp && b_is_comp ) || - //( c_is_comp && a_is_real && b_is_real ) || - //( c_is_real && a_is_comp && b_is_real ) || - //( c_is_real && a_is_real && b_is_comp ) 
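The corrected packm_b comment above describes B being packed into row-stored column panels of width NR, with rs_p = NR, cs_p = 1 and ps_p = k*NR. As a concrete illustration of that layout, the toy program below packs a column-major k x n matrix into such panels (zero-padding the partial last panel for simplicity); it is an independent sketch, not the BLIS packm kernel, and all names are illustrative.

/* Toy illustration of "packed column panels": B (k x n, column-major) is
   copied into panels of NR columns, each panel stored with row stride NR
   (rs_p) and column stride 1 (cs_p); consecutive panels sit ps_p = k*NR
   elements apart. */
#include <stdio.h>
#include <stdlib.h>

static void pack_b_col_panels( int k, int n, int NR,
                               const double* b, int rs_b, int cs_b,
                               double* p )
{
    const int  rs_p = NR;             /* row stride inside a panel    */
    const int  cs_p = 1;              /* column stride inside a panel */
    const long ps_p = (long)k * NR;   /* panel stride                 */

    for ( int j0 = 0; j0 < n; j0 += NR )          /* one panel per NR columns */
    {
        double* panel = p + ( j0 / NR ) * ps_p;
        for ( int i = 0; i < k; ++i )
            for ( int jj = 0; jj < NR; ++jj )
            {
                int j = j0 + jj;
                panel[ i*rs_p + jj*cs_p ] =
                    ( j < n ) ? b[ i*rs_b + j*cs_b ] : 0.0;  /* zero-pad edge */
            }
    }
}

int main( void )
{
    enum { K = 3, N = 5, NR = 4 };
    double b[ K*N ];
    for ( int j = 0; j < N; ++j )
        for ( int i = 0; i < K; ++i )
            b[ i + j*K ] = 10*i + j;              /* column-major B */

    int     npanels = ( N + NR - 1 ) / NR;
    double* p = calloc( (size_t)npanels * K * NR, sizeof(double) );
    pack_b_col_panels( K, N, NR, b, 1, K, p );

    for ( int i = 0; i < npanels * K; ++i )
    {
        for ( int jj = 0; jj < NR; ++jj ) printf( "%5.0f ", p[ i*NR + jj ] );
        printf( "\n" );
    }
    free( p );
    return 0;
}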
|| - //FALSE - TRUE - ) - { - if ( - ( c_is_single && a_is_single && b_is_single && mixeddomain ) || - ( c_is_single && a_is_single && b_is_single && comp_single ) || - ( c_is_single && a_is_single && b_is_single && comp_double ) || - ( c_is_single && a_is_single && b_is_double ) || - ( c_is_single && a_is_double && b_is_single ) || - ( c_is_double && a_is_single && b_is_single ) || - ( c_is_single && a_is_double && b_is_double ) || - ( c_is_double && a_is_single && b_is_double ) || - ( c_is_double && a_is_double && b_is_single ) || - ( c_is_double && a_is_double && b_is_double && comp_single ) || - ( c_is_double && a_is_double && b_is_double && comp_double ) || - ( c_is_double && a_is_double && b_is_double && mixeddomain ) || - FALSE - ) - bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); - else - bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); - } - else - bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); - return; - } -#else -#if 0 - // If any of the storage datatypes differ, or if the execution precision - // differs from the storage precision of C, utilize the mixed datatype - // code path. - // NOTE: We could check the exec dt against the storage dt of C, but for - // now we don't support the caller setting the execution domain - // explicitly. - if ( bli_obj_dt( a ) != bli_obj_dt( b ) || - bli_obj_dt( a ) != bli_obj_dt( c ) || - bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) - { - bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); - return; - } -#endif -#endif diff --git a/frame/3/gemm/bli_gemm_front_amd.c b/frame/3/gemm/bli_gemm_front_amd.c index 34b41f0568..b15d906dd8 100644 --- a/frame/3/gemm/bli_gemm_front_amd.c +++ b/frame/3/gemm/bli_gemm_front_amd.c @@ -319,89 +319,3 @@ void bli_gemm_front } // ----------------------------------------------------------------------------- - -#if 0 - if ( bli_obj_dt( a ) != bli_obj_dt( b ) || - bli_obj_dt( a ) != bli_obj_dt( c ) || - bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) - { - const bool a_is_real = bli_obj_is_real( a ); - const bool a_is_comp = bli_obj_is_complex( a ); - const bool b_is_real = bli_obj_is_real( b ); - const bool b_is_comp = bli_obj_is_complex( b ); - const bool c_is_real = bli_obj_is_real( c ); - const bool c_is_comp = bli_obj_is_complex( c ); - - const bool a_is_single = bli_obj_is_single_prec( a ); - const bool a_is_double = bli_obj_is_double_prec( a ); - const bool b_is_single = bli_obj_is_single_prec( b ); - const bool b_is_double = bli_obj_is_double_prec( b ); - const bool c_is_single = bli_obj_is_single_prec( c ); - const bool c_is_double = bli_obj_is_double_prec( c ); - - const bool comp_single = bli_obj_comp_prec( c ) == BLIS_SINGLE_PREC; - const bool comp_double = bli_obj_comp_prec( c ) == BLIS_DOUBLE_PREC; - - const bool mixeddomain = bli_obj_domain( c ) != bli_obj_domain( a ) || - bli_obj_domain( c ) != bli_obj_domain( b ); - - ( void )a_is_real; ( void )a_is_comp; - ( void )b_is_real; ( void )b_is_comp; - ( void )c_is_real; ( void )c_is_comp; - ( void )a_is_single; ( void )a_is_double; - ( void )b_is_single; ( void )b_is_double; - ( void )c_is_single; ( void )c_is_double; - ( void )comp_single; ( void )comp_double; - - if ( - //( c_is_comp && a_is_comp && b_is_real ) || - //( c_is_comp && a_is_real && b_is_comp ) || - //( c_is_real && a_is_comp && b_is_comp ) || - //( c_is_comp && a_is_real && b_is_real ) || - //( c_is_real && a_is_comp && b_is_real ) || - //( c_is_real && a_is_real && b_is_comp ) || - //FALSE - TRUE - ) - { - if ( - ( c_is_single && a_is_single && b_is_single && 
mixeddomain ) || - ( c_is_single && a_is_single && b_is_single && comp_single ) || - ( c_is_single && a_is_single && b_is_single && comp_double ) || - ( c_is_single && a_is_single && b_is_double ) || - ( c_is_single && a_is_double && b_is_single ) || - ( c_is_double && a_is_single && b_is_single ) || - ( c_is_single && a_is_double && b_is_double ) || - ( c_is_double && a_is_single && b_is_double ) || - ( c_is_double && a_is_double && b_is_single ) || - ( c_is_double && a_is_double && b_is_double && comp_single ) || - ( c_is_double && a_is_double && b_is_double && comp_double ) || - ( c_is_double && a_is_double && b_is_double && mixeddomain ) || - FALSE - ) - bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); - else - bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); - } - else - bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); - return; - } -#else -#if 0 - // If any of the storage datatypes differ, or if the execution precision - // differs from the storage precision of C, utilize the mixed datatype - // code path. - // NOTE: We could check the exec dt against the storage dt of C, but for - // now we don't support the caller setting the execution domain - // explicitly. - if ( bli_obj_dt( a ) != bli_obj_dt( b ) || - bli_obj_dt( a ) != bli_obj_dt( c ) || - bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) - { - bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); - return; - } -#endif -#endif - diff --git a/frame/3/gemm/bli_gemm_md.c b/frame/3/gemm/bli_gemm_md.c index 0f82b15f3e..68298c71ca 100644 --- a/frame/3/gemm/bli_gemm_md.c +++ b/frame/3/gemm/bli_gemm_md.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2017 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2017 - 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -156,92 +156,20 @@ mddm_t bli_gemm_md_ccr cntx_t** cntx ) { - mddm_t doms; - - // We assume that the requested computation domain is complex. - //dom_t dom_comp_in = bli_obj_comp_domain( c ); - //dom_t dom_comp_in = BLIS_COMPLEX; - - // For ccr, the computation (ukernel) will be real, but the execution - // will appear complex to other parts of the implementation. - doms.comp = BLIS_REAL; - doms.exec = BLIS_COMPLEX; - - // Here we construct the computation datatype, which for the ccr case - // is equal to the real projection of the execution datatype, and use - // that computation datatype to query the corresponding ukernel output - // preference. - const num_t dt = BLIS_REAL | bli_obj_comp_prec( c ); - const bool row_pref - = bli_cntx_l3_nat_ukr_prefers_rows_dt( dt, BLIS_GEMM_UKR, *cntx ); - - // We can only perform this case of mixed-domain gemm, C += A*B where - // B is real, if the microkernel prefers column output. If it prefers - // row output, we must induce a transposition and perform C += A*B - // where A (formerly B) is real. - if ( row_pref ) - { - bli_obj_swap( a, b ); - - bli_obj_induce_trans( a ); - bli_obj_induce_trans( b ); - bli_obj_induce_trans( c ); - - return bli_gemm_md_crc( a, b, beta, c, cntx_local, cntx ); - } - // Create a local copy of the context and then prepare to use this // context instead of the one passed in. *cntx_local = **cntx; *cntx = cntx_local; - // Copy the real domain blocksizes into the slots of their complex - // counterparts. 
- blksz_t* blksz_mr = bli_cntx_get_blksz( BLIS_MR, *cntx ); - blksz_t* blksz_nr = bli_cntx_get_blksz( BLIS_NR, *cntx ); - blksz_t* blksz_mc = bli_cntx_get_blksz( BLIS_MC, *cntx ); - blksz_t* blksz_nc = bli_cntx_get_blksz( BLIS_NC, *cntx ); - blksz_t* blksz_kc = bli_cntx_get_blksz( BLIS_KC, *cntx ); - - bli_blksz_copy_dt( BLIS_FLOAT, blksz_mr, BLIS_SCOMPLEX, blksz_mr ); - bli_blksz_copy_dt( BLIS_DOUBLE, blksz_mr, BLIS_DCOMPLEX, blksz_mr ); - - bli_blksz_copy_dt( BLIS_FLOAT, blksz_nr, BLIS_SCOMPLEX, blksz_nr ); - bli_blksz_copy_dt( BLIS_DOUBLE, blksz_nr, BLIS_DCOMPLEX, blksz_nr ); - - bli_blksz_copy_dt( BLIS_FLOAT, blksz_mc, BLIS_SCOMPLEX, blksz_mc ); - bli_blksz_copy_dt( BLIS_DOUBLE, blksz_mc, BLIS_DCOMPLEX, blksz_mc ); - - bli_blksz_copy_dt( BLIS_FLOAT, blksz_nc, BLIS_SCOMPLEX, blksz_nc ); - bli_blksz_copy_dt( BLIS_DOUBLE, blksz_nc, BLIS_DCOMPLEX, blksz_nc ); - - bli_blksz_copy_dt( BLIS_FLOAT, blksz_kc, BLIS_SCOMPLEX, blksz_kc ); - bli_blksz_copy_dt( BLIS_DOUBLE, blksz_kc, BLIS_DCOMPLEX, blksz_kc ); - - // Halve both the real and complex MR's (which are both real MR's). - bli_blksz_scale_def_max( 1, 2, BLIS_FLOAT, blksz_mr ); - bli_blksz_scale_def_max( 1, 2, BLIS_DOUBLE, blksz_mr ); - bli_blksz_scale_def_max( 1, 2, BLIS_SCOMPLEX, blksz_mr ); - bli_blksz_scale_def_max( 1, 2, BLIS_DCOMPLEX, blksz_mr ); - - // Halve both the real and complex MC's (which are both real MC's). - bli_blksz_scale_def_max( 1, 2, BLIS_FLOAT, blksz_mc ); - bli_blksz_scale_def_max( 1, 2, BLIS_DOUBLE, blksz_mc ); - bli_blksz_scale_def_max( 1, 2, BLIS_SCOMPLEX, blksz_mc ); - bli_blksz_scale_def_max( 1, 2, BLIS_DCOMPLEX, blksz_mc ); - - // Use the default pack schemas in the context. - - // static func_t* bli_cntx_get_l3_vir_ukrs( l3ukr_t ukr_id, cntx_t* cntx ) - func_t* l3_vir_ukrs = bli_cntx_get_l3_vir_ukrs( BLIS_GEMM_UKR, *cntx ); + //we must induce a transposition and perform C += A*B + // where A (formerly B) is real. + bli_obj_swap( a, b ); - // Rather than check which complex datatype dt_comp refers to, we set - // the mixed-domain virtual microkernel for both types. - bli_func_set_dt( bli_cgemm_md_c2r_ref, BLIS_SCOMPLEX, l3_vir_ukrs ); - bli_func_set_dt( bli_zgemm_md_c2r_ref, BLIS_DCOMPLEX, l3_vir_ukrs ); + bli_obj_induce_trans( a ); + bli_obj_induce_trans( b ); + bli_obj_induce_trans( c ); - // Return the computation and execution domains. - return doms; + return bli_gemm_md_crc( a, b, beta, c, cntx_local, cntx ); } // ----------------------------------------------------------------------------- @@ -268,29 +196,6 @@ mddm_t bli_gemm_md_crc doms.comp = BLIS_REAL; doms.exec = BLIS_COMPLEX; - // Here we construct the computation datatype, which for the crc case - // is equal to the real projection of the execution datatype, and use - // that computation datatype to query the corresponding ukernel output - // preference. - const num_t dt = BLIS_REAL | bli_obj_comp_prec( c ); - const bool col_pref - = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, *cntx ); - - // We can only perform this case of mixed-domain gemm, C += A*B where - // A is real, if the microkernel prefers row output. If it prefers - // column output, we must induce a transposition and perform C += A*B - // where B (formerly A) is real. - if ( col_pref ) - { - bli_obj_swap( a, b ); - - bli_obj_induce_trans( a ); - bli_obj_induce_trans( b ); - bli_obj_induce_trans( c ); - - return bli_gemm_md_ccr( a, b, beta, c, cntx_local, cntx ); - } - // Create a local copy of the context and then prepare to use this // context instead of the one passed in. 
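The rewritten bli_gemm_md_ccr above drops the microkernel-preference check and always reduces the ccr case (real B) to the existing crc case (real A) by logical transposition: C += A*B is computed as C^T += B^T * A^T, so after swapping the operands and toggling the transpose bits the real matrix sits in the A position. A small numerical check of that identity, using plain double arrays rather than BLIS objects, is sketched below; it only verifies the algebra behind the swap/induce-transpose calls and is not part of the library.

/* Numerical check of the identity used above: C += A*B equals
   (C^T += B^T * A^T)^T. Plain row-major arrays, no BLIS objects. */
#include <math.h>
#include <stdio.h>

enum { M = 2, K = 3, N = 4 };

/* c (m x n) += a (m x k) * b (k x n), all row-major. */
static void gemm_acc( int m, int n, int k,
                      const double* a, int lda,
                      const double* b, int ldb,
                      double* c, int ldc )
{
    for ( int i = 0; i < m; ++i )
        for ( int j = 0; j < n; ++j )
            for ( int p = 0; p < k; ++p )
                c[ i*ldc + j ] += a[ i*lda + p ] * b[ p*ldb + j ];
}

int main( void )
{
    double a[M*K], b[K*N], c1[M*N] = {0}, c2t[N*M] = {0};
    for ( int i = 0; i < M*K; ++i ) a[i] = i + 1;
    for ( int i = 0; i < K*N; ++i ) b[i] = 2*i - 3;

    /* Direct:      C   += A   * B   */
    gemm_acc( M, N, K, a, K, b, N, c1, N );

    /* Transposed:  C^T += B^T * A^T (transposes materialized explicitly). */
    double at[K*M], bt[N*K];
    for ( int i = 0; i < M; ++i )
        for ( int p = 0; p < K; ++p ) at[ p*M + i ] = a[ i*K + p ];
    for ( int p = 0; p < K; ++p )
        for ( int j = 0; j < N; ++j ) bt[ j*K + p ] = b[ p*N + j ];
    gemm_acc( N, M, K, bt, K, at, M, c2t, M );

    double err = 0.0;
    for ( int i = 0; i < M; ++i )
        for ( int j = 0; j < N; ++j )
            err = fmax( err, fabs( c1[ i*N + j ] - c2t[ j*M + i ] ) );
    printf( "max difference = %g\n", err );   /* expect 0 */
    return 0;
}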
*cntx_local = **cntx; diff --git a/frame/3/gemmt/bli_gemmt_front.c b/frame/3/gemmt/bli_gemmt_front.c index 86940c1bd2..b2155a0bcd 100644 --- a/frame/3/gemmt/bli_gemmt_front.c +++ b/frame/3/gemmt/bli_gemmt_front.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -278,88 +278,3 @@ void bli_gemmt_front } // ----------------------------------------------------------------------------- - -#if 0 - if ( bli_obj_dt( a ) != bli_obj_dt( b ) || - bli_obj_dt( a ) != bli_obj_dt( c ) || - bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) - { - const bool a_is_real = bli_obj_is_real( a ); - const bool a_is_comp = bli_obj_is_complex( a ); - const bool b_is_real = bli_obj_is_real( b ); - const bool b_is_comp = bli_obj_is_complex( b ); - const bool c_is_real = bli_obj_is_real( c ); - const bool c_is_comp = bli_obj_is_complex( c ); - - const bool a_is_single = bli_obj_is_single_prec( a ); - const bool a_is_double = bli_obj_is_double_prec( a ); - const bool b_is_single = bli_obj_is_single_prec( b ); - const bool b_is_double = bli_obj_is_double_prec( b ); - const bool c_is_single = bli_obj_is_single_prec( c ); - const bool c_is_double = bli_obj_is_double_prec( c ); - - const bool comp_single = bli_obj_comp_prec( c ) == BLIS_SINGLE_PREC; - const bool comp_double = bli_obj_comp_prec( c ) == BLIS_DOUBLE_PREC; - - const bool mixeddomain = bli_obj_domain( c ) != bli_obj_domain( a ) || - bli_obj_domain( c ) != bli_obj_domain( b ); - - ( void )a_is_real; ( void )a_is_comp; - ( void )b_is_real; ( void )b_is_comp; - ( void )c_is_real; ( void )c_is_comp; - ( void )a_is_single; ( void )a_is_double; - ( void )b_is_single; ( void )b_is_double; - ( void )c_is_single; ( void )c_is_double; - ( void )comp_single; ( void )comp_double; - - if ( - //( c_is_comp && a_is_comp && b_is_real ) || - //( c_is_comp && a_is_real && b_is_comp ) || - //( c_is_real && a_is_comp && b_is_comp ) || - //( c_is_comp && a_is_real && b_is_real ) || - //( c_is_real && a_is_comp && b_is_real ) || - //( c_is_real && a_is_real && b_is_comp ) || - //FALSE - TRUE - ) - { - if ( - ( c_is_single && a_is_single && b_is_single && mixeddomain ) || - ( c_is_single && a_is_single && b_is_single && comp_single ) || - ( c_is_single && a_is_single && b_is_single && comp_double ) || - ( c_is_single && a_is_single && b_is_double ) || - ( c_is_single && a_is_double && b_is_single ) || - ( c_is_double && a_is_single && b_is_single ) || - ( c_is_single && a_is_double && b_is_double ) || - ( c_is_double && a_is_single && b_is_double ) || - ( c_is_double && a_is_double && b_is_single ) || - ( c_is_double && a_is_double && b_is_double && comp_single ) || - ( c_is_double && a_is_double && b_is_double && comp_double ) || - ( c_is_double && a_is_double && b_is_double && mixeddomain ) || - FALSE - ) - bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); - else - bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); - } - else - bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); - return; - } -#else -#if 0 - // If any of the storage datatypes differ, or if the execution precision - // differs from the storage precision of C, utilize the mixed datatype - // code path. 
- // NOTE: We could check the exec dt against the storage dt of C, but for - // now we don't support the caller setting the execution domain - // explicitly. - if ( bli_obj_dt( a ) != bli_obj_dt( b ) || - bli_obj_dt( a ) != bli_obj_dt( c ) || - bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) - { - bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); - return; - } -#endif -#endif diff --git a/frame/3/gemmt/bli_gemmt_sup_var1n2m.c b/frame/3/gemmt/bli_gemmt_sup_var1n2m.c index 382ca6f67d..a026ed8d39 100644 --- a/frame/3/gemmt/bli_gemmt_sup_var1n2m.c +++ b/frame/3/gemmt/bli_gemmt_sup_var1n2m.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -55,6 +55,63 @@ typedef void (*FUNCPTR_T) rntm_t* restrict rntm, thrinfo_t* restrict thread ); + + +// Declaration of gemmt specific kernels function pointer +// This is aligned to bli_dgemmsup_rv_haswell_asm_6x8m function protype. +typedef void (*gemmt_ker_ft) + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +//Look-up table for Gemmt Upper Variant Kernels +gemmt_ker_ft ker_fpus[14] = + { + bli_dgemmsup_rv_haswell_asm_6x8m_0x0_U, + bli_dgemmsup_rv_haswell_asm_6x8m_6x0_U, + bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U, + bli_dgemmsup_rv_haswell_asm_6x8m_12x8_U, + bli_dgemmsup_rv_haswell_asm_6x8m_12x16_U, + bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U, + bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U, + bli_dgemmsup_rd_haswell_asm_6x8m_0x0_U, + bli_dgemmsup_rd_haswell_asm_6x8m_6x0_U, + bli_dgemmsup_rd_haswell_asm_6x8m_6x8_U, + bli_dgemmsup_rd_haswell_asm_6x8m_12x8_U, + bli_dgemmsup_rd_haswell_asm_6x8m_12x16_U, + bli_dgemmsup_rd_haswell_asm_6x8m_18x16_U, + bli_dgemmsup_rd_haswell_asm_6x8m_0x0_combined_U}; + +//Look-up table for Gemmt Lower Variant Kernels +gemmt_ker_ft ker_fpls[14] = +{ + bli_dgemmsup_rv_haswell_asm_6x8m_0x0_L, + bli_dgemmsup_rv_haswell_asm_6x8m_6x0_L, + bli_dgemmsup_rv_haswell_asm_6x8m_6x8_L, + bli_dgemmsup_rv_haswell_asm_6x8m_12x8_L, + bli_dgemmsup_rv_haswell_asm_6x8m_12x16_L, + bli_dgemmsup_rv_haswell_asm_6x8m_18x16_L, + bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L, + bli_dgemmsup_rd_haswell_asm_6x8m_0x0_L, + bli_dgemmsup_rd_haswell_asm_6x8m_6x0_L, + bli_dgemmsup_rd_haswell_asm_6x8m_6x8_L, + bli_dgemmsup_rd_haswell_asm_6x8m_12x8_L, + bli_dgemmsup_rd_haswell_asm_6x8m_12x16_L, + bli_dgemmsup_rd_haswell_asm_6x8m_18x16_L, + bli_dgemmsup_rd_haswell_asm_6x8m_16x12_combined_L +}; + // // -- var1n -------------------------------------------------------------------- // @@ -1501,7 +1558,7 @@ void PASTEMACT(ch,opname,uplo,varname) \ \ /* storage-scheme of ct should be same as that of C. Since update routines only support row-major order, - col_pref flag is used to induce transpose to matrices before + col_pref flag is used to induce transpose to matrices before passing to update routine whenever C is col-stored */ \ const bool col_pref = (rs_c == 1)? 
1 : 0; \ \ @@ -1833,40 +1890,142 @@ void PASTEMACT(ch,opname,uplo,varname) \ { \ const dim_t mr_cur = (i+MR-1) < mc_cur ? MR : mc_cur - i; \ \ - /* Invoke the gemmsup millikernel. */ \ - gemmsup_ker \ - ( \ - conja, \ - conjb, \ - mr_cur, \ - nr_cur, \ - kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ - /* Scale the bottom edge of C and add the result from above. */ \ - /* If c and ct are col-major, induce transpose and call update for upper-triangle of C */ \ - if( col_pref ) \ - { \ - PASTEMAC(ch,update_upper_triang)( n_off_cblock, m_off_cblock, \ - nr_cur, mr_cur, \ - ct, cs_ct, rs_ct, \ - beta_use, \ - c_ir, cs_c, rs_c ); \ + /* Prerequisites : MR = 6, NR = 8. + An optimization: allow the last jr iteration to contain up to NRE + In DGEMMT API implementation, kernel operates on 6x8 block. MR and + NR are set as 6 and 8 respectively. 24 being the LCM of 6 and 8, + the diagonal pattern repeats for every 24x24 block. + This pattern is exploited to achieve the optimization in diagonal + blocks by computing only the required elements. In the previous + implementation, all the 48 outputs of the given 6x8 block are + computed and stored into a temporary buffer. Later, the required + elements are copied into the final C output buffer. + With this optimization, we are avoiding copy operation and also + reducing the number of computations. + Variables m_off_24 and n_off_24 respectively store the m and n + offsets from the starting point of the corresponding 24x24 block. + Variables m_idx and n_idx store indices of the current 6x8 block + along m and n dimensions, in 24x24 block. m_idx is computed as + (m_off_24 / MR) while n_idx is computed as (n_off_24 / NR). + Range of m_idx is 0 <= m_idx <= 3 and the range of n_idx is + 0 <= n_idx <= 2. Based on these indices, for the given 6x8 block, + logic is implemented to identify the relevant kernel from the + look-up table. + During instances, where m is not a multiple of 6 or n is not a + multiple of 8, it goes to the default gemm kernel. MR and NR must be + 6 and 8 for these kernels to achieve the expected functionality.*/ \ +\ + dim_t m_off_24 = m_off_cblock % 24; \ + dim_t n_off_24 = n_off_cblock % 24; \ + dim_t m_idx = (dim_t)(m_off_24 / MR); \ + dim_t n_idx = (dim_t)(n_off_24 / NR); \ +\ + /* Check if m, n indices are multiple of MR and NR respectively + and current block is a complete 6x8 block */ \ + bool idx_supported = ((m_off_24 % MR) == 0) && ((n_off_24 % NR) == 0) && (mr_cur == MR) && (nr_cur == NR); \ +\ + /* m_idx and n_idx would be equal only if the current block is + a diagonal block */\ + if( (dt == BLIS_DOUBLE) && (m_idx == n_idx) && (idx_supported) ) { \ + /* index of kernel in lookup table is 2*m_idx) */ \ + dim_t ker_idx; \ + ker_idx = m_idx<<1; \ +\ + /* If there is another 6x8 diagonal block pending for computation + after the current 6x8 diagonal block, then the two blocks can + be computed together(12x8). This combined kernel is implemented + only for the case where n_idx = 2 i.e., n_off_24 = 16. To call + this, it has to be ensured that at least 12 rows are pending in + C for computation. (m_off + 2 * MR <=m). 
Usage of this combined + kernel saves the entire time to execute one kernel*/ \ + if( (n_idx == 2) && (m_off_cblock + MR + MR <= m) ) {\ + ker_idx = 6; /* use combined kernel, index of combined kernel + in lookup table is 6 */\ + } \ + /* use rd kernel if B is column major storage */ \ + if( stor_id == BLIS_RRC ) { \ + ker_idx += 7; /* index of rd kernel*/ \ + } \ + gemmt_ker_ft ker_fp = ker_fpls[ker_idx]; \ + ker_fp \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + (double*) alpha_cast, \ + (double*) a_ir, rs_a_use, cs_a_use, \ + (double*) b_jr, rs_b_use, cs_b_use, \ + (double*) beta_use, \ + (double*) c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ - else \ - { \ - PASTEMAC(ch,update_lower_triang)( m_off_cblock, n_off_cblock, \ - mr_cur, nr_cur, \ - ct, rs_ct, cs_ct, \ - beta_use, \ - c_ir, rs_c, cs_c ); \ + /* 6x8 block where m_idx == n_idx+1 also has some parts of the diagonal */\ + else if( (dt == BLIS_DOUBLE) && (m_idx == n_idx+1) && (idx_supported) ) { \ + /* If current block was already computed in the combined kernel it + can be skipped combined kernel is only implemented for n_idx=2, + i == m_zero is only true for the first iteration therefore if + i == m_zero then the current 6x8 block was not computed in + combined kernel*/ \ + if( (n_idx != 2) || (i == m_zero) ) { \ + dim_t ker_idx = (n_idx << 1) + 1; \ + /* use rd kernel if B is column major storage */ \ + if( stor_id == BLIS_RRC ) { ker_idx += 7; } \ + gemmt_ker_ft ker_fp = ker_fpls[ker_idx]; \ + ker_fp \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + (double*) alpha_cast, \ + (double*) a_ir, rs_a_use, cs_a_use, \ + (double*) b_jr, rs_b_use, cs_b_use, \ + (double*) beta_use, \ + (double*) c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ } \ + /* Call the regular kernel for non applicable cases */ \ + else { \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ + if( col_pref ) \ + { \ + PASTEMAC(ch,update_upper_triang)( n_off_cblock, m_off_cblock, \ + nr_cur, mr_cur, \ + ct, cs_ct, rs_ct, \ + beta_use, \ + c_ir, cs_c, rs_c ); \ + } \ + else \ + { \ + PASTEMAC(ch,update_lower_triang)( m_off_cblock, n_off_cblock, \ + mr_cur, nr_cur, \ + ct, rs_ct, cs_ct, \ + beta_use, \ + c_ir, rs_c, cs_c ); \ + }\ + }\ \ a_ir += ps_a_use; \ c_ir += irstep_c; \ @@ -2410,39 +2569,140 @@ void PASTEMACT(ch,opname,uplo,varname) \ { \ const dim_t mr_cur = (i+MR-1) < mc_cur ? MR : mc_cur - i; \ \ - /* Invoke the gemmsup millikernel. */ \ - gemmsup_ker \ - ( \ - conja, \ - conjb, \ - mr_cur, \ - nr_cur, \ - kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - if( col_pref ) \ - { \ - PASTEMAC(ch,update_lower_triang)( n_off_cblock, m_off_cblock, \ - nr_cur, mr_cur, \ - ct, cs_ct, rs_ct, \ - beta_use, \ - c_ir, cs_c, rs_c ); \ + /* Prerequisites : MR = 6, NR = 8. + An optimization: allow the last jr iteration to contain up to NRE + In DGEMMT API implementation, kernel operates on 6x8 block. MR and + NR are set as 6 and 8 respectively. 24 being the LCM of 6 and 8, + the diagonal pattern repeats for every 24x24 block. + This pattern is exploited to achieve the optimization in diagonal + blocks by computing only the required elements. In the previous + implementation, all the 48 outputs of the given 6x8 block are + computed and stored into a temporary buffer. 
Later, the required + elements are copied into the final C output buffer. + With this optimization, we are avoiding copy operation and also + reducing the number of computations. + Variables m_off_24 and n_off_24 respectively store the m and n + offsets from the starting point of the corresponding 24x24 block. + Variables m_idx and n_idx store indices of the current 6x8 block + along m and n dimensions, in 24x24 block. m_idx is computed as + (m_off_24 / MR) while n_idx is computed as (n_off_24 / NR). + Range of m_idx is 0 <= m_idx <= 3 and the range of n_idx is + 0 <= n_idx <= 2. Based on these indices, for the given 6x8 block, + logic is implemented to identify the relevant kernel from the + look-up table. + During instances, where m is not a multiple of 6 or n is not a + multiple of 8, it goes to the default gemm kernel. MR and NR must be + 6 and 8 for these kernels to achieve the expected functionality.*/ \ + dim_t m_off_24 = m_off_cblock % 24; \ + dim_t n_off_24 = n_off_cblock % 24; \ + dim_t m_idx = (dim_t)(m_off_24 / MR); \ + dim_t n_idx = (dim_t)(n_off_24 / NR); \ +\ + /* Check if m, n indices are multiple of MR and NR respectively + and current block is a complete 6x8 block */ \ + bool idx_supported = ((m_off_24 % MR) == 0) && ((n_off_24 % NR) == 0) && (mr_cur==MR) && (nr_cur==NR); \ +\ + /* m_idx and n_idx would be equal only if the current block is + a diagonal block */\ + if( (dt == BLIS_DOUBLE) && (m_idx == n_idx) && idx_supported ) { \ + dim_t ker_idx = m_idx<<1; \ + /* If there is another 6x8 diagonal block pending for computation + after the current 6x8 diagonal block, then the two blocks can + be computed together(12x8). This combined kernel is implemented + only for the case where n_idx = 0 i.e., n_off_24 = 0. To call + this, it has to be ensured that at least 12 rows are pending in + C for computation (i+ MR + MR <= mc_cur). 
Usage of this combined + kernel saves the entire time to execute one kernel*/ \ + if( (n_idx == 0) && (i+ MR + MR <= mc_cur) ) { \ + ker_idx = 6; /* use combined kernel, index of combined kernel + in lookup table is 6 */\ + } \ + /* if B is column storage we use rd kernel*/ \ + if( stor_id == BLIS_RRC ) { \ + ker_idx += 7; /* index of rd kernel*/\ + } \ + gemmt_ker_ft ker_fp = ker_fpus[ker_idx]; \ + ker_fp \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + (double*) alpha_cast, \ + (double*) a_ir, rs_a_use, cs_a_use, \ + (double*) b_jr, rs_b_use, cs_b_use, \ + (double*) beta_use, \ + (double*) c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ - else \ - { \ - PASTEMAC(ch,update_upper_triang)( m_off_cblock, n_off_cblock, \ - mr_cur, nr_cur, \ - ct, rs_ct, cs_ct, \ - beta_use, \ - c_ir, rs_c, cs_c ); \ + /* 6x8 block where m_idx == n_idx+1 also has some parts of the diagonal */\ + else if( (dt == BLIS_DOUBLE) && (m_idx == n_idx+1) && (idx_supported) ) { \ + /* If current block was already computed in the combined kernel it + can be skipped combined kernel is only implemented for n_idx=0, + i == m_rect is only true for the first iteration therefore if + i == m_rect then the current 6x8 block was not computed in + combined kernel*/ \ + if( (n_idx != 0) || (i == m_rect) ) { \ + dim_t ker_idx = (n_idx << 1) + 1 ; \ + /* use rd kernel if B is column major storage */ \ + if( stor_id == BLIS_RRC ) { ker_idx += 7; } \ + gemmt_ker_ft ker_fp = ker_fpus[ker_idx]; \ + ker_fp \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + (double*) alpha_cast, \ + (double*) a_ir, rs_a_use, cs_a_use, \ + (double*) b_jr, rs_b_use, cs_b_use, \ + (double*) beta_use, \ + (double*) c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ } \ + /* call the regular kernel for non applicable cases */ \ + else { \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ + \ + if( col_pref ) \ + { \ + PASTEMAC(ch,update_lower_triang)( n_off_cblock, m_off_cblock, \ + nr_cur, mr_cur, \ + ct, cs_ct, rs_ct, \ + beta_use, \ + c_ir, cs_c, rs_c ); \ + } \ + else \ + { \ + PASTEMAC(ch,update_upper_triang)( m_off_cblock, n_off_cblock, \ + mr_cur, nr_cur, \ + ct, rs_ct, cs_ct, \ + beta_use, \ + c_ir, rs_c, cs_c ); \ + } \ + } \ +\ a_ir += ps_a_use; \ c_ir += irstep_c; \ m_off_cblock += mr_cur; \ diff --git a/frame/3/trmm/CMakeLists.txt b/frame/3/trmm/CMakeLists.txt index a3845f3858..49106e4b10 100644 --- a/frame/3/trmm/CMakeLists.txt +++ b/frame/3/trmm/CMakeLists.txt @@ -12,6 +12,7 @@ target_sources("${PROJECT_NAME}" if(${TARGET_ARCH} STREQUAL zen OR ${TARGET_ARCH} STREQUAL zen2 OR ${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL zen4 OR ${TARGET_ARCH} STREQUAL amdzen) target_sources("${PROJECT_NAME}" PRIVATE diff --git a/frame/3/trsm/bli_trsm_front.c b/frame/3/trsm/bli_trsm_front.c index f964faf0dd..35cd2d4b85 100644 --- a/frame/3/trsm/bli_trsm_front.c +++ b/frame/3/trsm/bli_trsm_front.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
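The lookup-table dispatch added for DGEMMT above picks one of 14 specialized 6x8 kernels from the position of the current block inside its 24x24 tile: ker_idx = 2*m_idx for blocks on the diagonal, 2*n_idx + 1 for the partial-diagonal blocks with m_idx == n_idx + 1, 6 for the combined 12x8 case, and +7 selects the rd kernels used with RRC storage. The standalone program below reproduces just that index arithmetic so the table layout can be checked at a glance; it calls no BLIS kernel, models the lower-variant combined case (n_idx == 2, the upper variant uses n_idx == 0), and omits the "already computed by the combined kernel" skip.

/* Standalone reproduction of the gemmt kernel-index arithmetic used with
   the ker_fpls/ker_fpus lookup tables above (MR = 6, NR = 8, pattern
   repeats every 24x24). Prints which table slot each 6x8 block of one
   24x24 tile would use; -1 means the generic gemmsup millikernel. */
#include <stdbool.h>
#include <stdio.h>

enum { MR = 6, NR = 8 };

static int gemmt_ker_index( long m_off, long n_off, bool use_rd, bool combined_ok )
{
    long m_idx = ( m_off % 24 ) / MR;   /* 0..3 */
    long n_idx = ( n_off % 24 ) / NR;   /* 0..2 */
    int  idx;

    if ( m_idx == n_idx )               /* block containing the diagonal    */
        idx = ( combined_ok && n_idx == 2 ) ? 6 : (int)( m_idx << 1 );
    else if ( m_idx == n_idx + 1 )      /* block just below the diagonal    */
        idx = (int)( ( n_idx << 1 ) + 1 );
    else
        return -1;                      /* fully off-diagonal: generic path */

    return use_rd ? idx + 7 : idx;      /* rd kernels occupy slots 7..13    */
}

int main( void )
{
    for ( long i = 0; i < 24; i += MR )
    {
        for ( long j = 0; j < 24; j += NR )
            printf( "%3d", gemmt_ker_index( i, j, false, false ) );
        printf( "\n" );
    }
    return 0;
}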
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,6 +36,7 @@ #include "blis.h" //#define PRINT_SMALL_TRSM_INFO + void bli_trsm_front ( side_t side, @@ -151,6 +152,24 @@ void bli_trsm_front // in bli_packm_init(). if ( bli_cntx_method( cntx ) == BLIS_NAT ) { +#if defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_ZEN4) + /* Zen4 TRSM Fixme: + * + * On Zen4 we want to use AVX-512 kernels for GEMM and AVX2 kernels + * for TRSM (Till we implemente TRSM AVX-512 kernels) + * + * The AVX2 kernels use different block sizes then AVX512 kernels + * Here we override the default block sizes in the context with AVX2 + * specific block size used in GEMMTRSM kernerls. + * + * We need to revisit this when TRSM AVX-512 kernels are implemented. + */ + if ( (bli_arch_query_id() == BLIS_ARCH_ZEN4) && + (bli_obj_dt(a) == BLIS_FLOAT) ) + { + bli_zen4_override_trsm_blkszs(cntx); + } +#endif bli_obj_set_pack_schema( BLIS_PACKED_ROW_PANELS, &a_local ); bli_obj_set_pack_schema( BLIS_PACKED_COL_PANELS, &b_local ); } @@ -177,6 +196,20 @@ void bli_trsm_front rntm, cntl ); + +#if defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_ZEN4) + /* Zen4 TRSM Fixme: + * + * We have overrding the block sizes at the start of this function + * Since the context is created only once we need to ensure that the + * default block sizes are restored for the subsequent operations. + */ + if ( (bli_arch_query_id() == BLIS_ARCH_ZEN4) && + (bli_obj_dt(a) == BLIS_FLOAT) ) + { + bli_zen4_restore_default_blkszs(cntx); + } +#endif AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); } diff --git a/frame/3/trsm/bli_trsm_ll_ker_var2.c b/frame/3/trsm/bli_trsm_ll_ker_var2.c index 5426348c83..a15f39fc3c 100644 --- a/frame/3/trsm/bli_trsm_ll_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ll_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -174,6 +174,25 @@ void PASTEMAC(ch,varname) \ gemmtrsm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMMTRSM_L_UKR, cntx ); \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Zen4 TRSM Fixme: + * + * On Zen4 we want to use AVX-512 kernels for GEMM and AVX2 kernels + * for TRSM (Till we implemente TRSM AVX-512 kernels) + * + * The AVX2 kernels for TRSM are enabled in the context, but they + * are compatible with only AVX2 version of GEMM kernels. + * + * Here we force the GEMM kernels to the AVX2 varients for float and double. + * For scomplex and dcomplex reference path is retained as is. + * + * We need to revisit this when TRSM AVX-512 kernels are implemented. + */ \ + if ((bli_arch_query_id() == BLIS_ARCH_ZEN4) && \ + (dt == BLIS_FLOAT)) \ + { \ + gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_AVX2_UKR, cntx ); \ + } \ \ /* Temporary C buffer for edge cases. Note that the strides of this temporary buffer are set so that they match the storage of the diff --git a/frame/3/trsm/bli_trsm_lu_ker_var2.c b/frame/3/trsm/bli_trsm_lu_ker_var2.c index 0d4e2e0ba6..48e4588f52 100644 --- a/frame/3/trsm/bli_trsm_lu_ker_var2.c +++ b/frame/3/trsm/bli_trsm_lu_ker_var2.c @@ -5,7 +5,7 @@ libraries. 
Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -174,6 +174,25 @@ void PASTEMAC(ch,varname) \ gemmtrsm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMMTRSM_U_UKR, cntx ); \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Zen4 TRSM Fixme: + * + * On Zen4 we want to use AVX-512 kernels for GEMM and AVX2 kernels + * for TRSM (Till we implemente TRSM AVX-512 kernels) + * + * The AVX2 kernels for TRSM are enabled in the context, but they + * are compatible with only AVX2 version of GEMM kernels. + * + * Here we force the GEMM kernels to the AVX2 varients for float and double. + * For scomplex and dcomplex reference path is retained as is. + * + * We need to revisit this when TRSM AVX-512 kernels are implemented. + */ \ + if ((bli_arch_query_id() == BLIS_ARCH_ZEN4) && \ + (dt == BLIS_FLOAT)) \ + { \ + gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_AVX2_UKR, cntx ); \ + } \ \ /* Temporary C buffer for edge cases. Note that the strides of this temporary buffer are set so that they match the storage of the diff --git a/frame/3/trsm/bli_trsm_rl_ker_var2.c b/frame/3/trsm/bli_trsm_rl_ker_var2.c index 396fb4af12..2705a747ac 100644 --- a/frame/3/trsm/bli_trsm_rl_ker_var2.c +++ b/frame/3/trsm/bli_trsm_rl_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -180,6 +180,25 @@ void PASTEMAC(ch,varname) \ gemmtrsm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMMTRSM_U_UKR, cntx ); \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Zen4 TRSM Fixme: + * + * On Zen4 we want to use AVX-512 kernels for GEMM and AVX2 kernels + * for TRSM (Till we implemente TRSM AVX-512 kernels) + * + * The AVX2 kernels for TRSM are enabled in the context, but they + * are compatible with only AVX2 version of GEMM kernels. + * + * Here we force the GEMM kernels to the AVX2 varients for float and double. + * For scomplex and dcomplex reference path is retained as is. + * + * We need to revisit this when TRSM AVX-512 kernels are implemented. + */ \ + if ((bli_arch_query_id() == BLIS_ARCH_ZEN4) && \ + (dt == BLIS_FLOAT)) \ + { \ + gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_AVX2_UKR, cntx ); \ + } \ \ /* Temporary C buffer for edge cases. Note that the strides of this temporary buffer are set so that they match the storage of the diff --git a/frame/3/trsm/bli_trsm_ru_ker_var2.c b/frame/3/trsm/bli_trsm_ru_ker_var2.c index 8b73b702f0..dc37614eb6 100644 --- a/frame/3/trsm/bli_trsm_ru_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ru_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
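The Zen4 TRSM workaround above has two halves: bli_trsm_front() temporarily overrides the context block sizes to AVX2-friendly values for single precision and restores the defaults afterwards, and the trsm macro-kernels force the AVX2 GEMM microkernel so it matches the AVX2 gemmtrsm kernels. The schematic below shows only the "override, run, restore" bracket; the structs, helper bodies, and block-size numbers are assumptions for illustration, with bli_zen4_override_trsm_blkszs()/bli_zen4_restore_default_blkszs() being the real helpers named in the patch (which operate on the shared cntx_t rather than on a saved copy).

/* Schematic (not BLIS code) of the bracket added to bli_trsm_front():
   override the block sizes for the float/Zen4 case, run, then restore so
   the single shared context is left unchanged for later operations. */
#include <stdbool.h>
#include <stdio.h>

typedef struct { int mr, nr, mc, kc, nc; } blkszs_t;
typedef struct { blkszs_t l3; } cntx_sketch_t;

static const blkszs_t avx2_trsm_blkszs = { 6, 16, 144, 256, 4080 };  /* placeholder values */

static void override_trsm_blkszs( cntx_sketch_t* cntx, blkszs_t* saved )
{
    *saved   = cntx->l3;               /* remember the AVX-512 defaults   */
    cntx->l3 = avx2_trsm_blkszs;       /* install AVX2-compatible sizes   */
}

static void restore_default_blkszs( cntx_sketch_t* cntx, const blkszs_t* saved )
{
    cntx->l3 = *saved;                 /* put the defaults back           */
}

static void run_trsm( cntx_sketch_t* cntx, bool is_zen4, bool is_float )
{
    blkszs_t saved;
    bool     overridden = is_zen4 && is_float;

    if ( overridden ) override_trsm_blkszs( cntx, &saved );

    printf( "trsm runs with MR=%d NR=%d KC=%d\n",
            cntx->l3.mr, cntx->l3.nr, cntx->l3.kc );

    if ( overridden ) restore_default_blkszs( cntx, &saved );
}

int main( void )
{
    cntx_sketch_t cntx = { { 32, 12, 144, 480, 4080 } };  /* placeholder defaults */
    run_trsm( &cntx, true, true );
    printf( "after trsm: MR=%d NR=%d (defaults restored)\n",
            cntx.l3.mr, cntx.l3.nr );
    return 0;
}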
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -179,6 +179,25 @@ void PASTEMAC(ch,varname) \ gemmtrsm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMMTRSM_L_UKR, cntx ); \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Zen4 TRSM Fixme: + * + * On Zen4 we want to use AVX-512 kernels for GEMM and AVX2 kernels + * for TRSM (Till we implemente TRSM AVX-512 kernels) + * + * The AVX2 kernels for TRSM are enabled in the context, but they + * are compatible with only AVX2 version of GEMM kernels. + * + * Here we force the GEMM kernels to the AVX2 varients for float and double. + * For scomplex and dcomplex reference path is retained as is. + * + * We need to revisit this when TRSM AVX-512 kernels are implemented. + */ \ + if ((bli_arch_query_id() == BLIS_ARCH_ZEN4) && \ + (dt == BLIS_FLOAT)) \ + { \ + gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_AVX2_UKR, cntx ); \ + } \ \ /* Temporary C buffer for edge cases. Note that the strides of this temporary buffer are set so that they match the storage of the diff --git a/frame/base/bli_arch.c b/frame/base/bli_arch.c index 153787d3ed..fecc353161 100644 --- a/frame/base/bli_arch.c +++ b/frame/base/bli_arch.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -81,9 +81,19 @@ void bli_arch_set_id( void ) bool do_logging = bli_env_get_var( "BLIS_ARCH_DEBUG", 0 ); bli_arch_set_logging( do_logging ); - // Check the environment variable BLIS_ARCH_TYPE to see if the user - // requested that we use a specific subconfiguration. - dim_t req_id = bli_env_get_var( "BLIS_ARCH_TYPE", -1 ); + // DISABLE_BLIS_ARCH_TYPE and BLIS_CONFIGURETIME_CPUID seem similar but + // have different use cases: + // * BLIS_CONFIGURETIME_CPUID is used by the "configure auto" option to + // select a single code path, and affects other parts of the code. + // * DISABLE_BLIS_ARCH_TYPE disables user selection of code path here in + // builds with multiple code paths. + +#ifndef DISABLE_BLIS_ARCH_TYPE + // Check the environment variable (that "__blis_arch_type_name" is + // defined to be) to see if the user requested that we use a specific + // subconfiguration. "__blis_arch_type_name" will be defined by the + // configure command in bli_config.h, with the default name of BLIS_ARCH_TYPE + dim_t req_id = bli_env_get_var_arch_type( __blis_arch_type_name, -1 ); #ifndef BLIS_CONFIGURETIME_CPUID if ( req_id != -1 ) @@ -118,6 +128,8 @@ void bli_arch_set_id( void ) id = req_id; } else +#endif + #endif { // BLIS_ARCH_TYPE was unset. Proceed with normal subconfiguration @@ -154,6 +166,9 @@ void bli_arch_set_id( void ) #endif // AMD microarchitectures. + #ifdef BLIS_FAMILY_ZEN4 + id = BLIS_ARCH_ZEN4; + #endif #ifdef BLIS_FAMILY_ZEN3 id = BLIS_ARCH_ZEN3; #endif @@ -227,8 +242,14 @@ void bli_arch_set_id( void ) // enumeration that is typedef'ed in bli_type_defs.h. That is, the // index order of each string should correspond to the implied/assigned // enum value given to the corresponding BLIS_ARCH_ value. 
+// This must also be kept up-to-date with the bli_env_get_var_arch_type() +// function in bli_env.c static char* config_name[ BLIS_NUM_ARCHS ] = { + "error", + + "generic", + "skx", "knl", "knc", @@ -236,6 +257,7 @@ static char* config_name[ BLIS_NUM_ARCHS ] = "sandybridge", "penryn", + "zen4", "zen3", "zen2", "zen", diff --git a/frame/base/bli_blksz.c b/frame/base/bli_blksz.c index f3891dbbba..a4d937bc2f 100644 --- a/frame/base/bli_blksz.c +++ b/frame/base/bli_blksz.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -272,7 +272,7 @@ dim_t bli_determine_blocksize_f b_alg = bli_blksz_get_def( dt, bsize ); b_max = bli_blksz_get_max( dt, bsize ); - // If b_use != 0, this means that trsm blocksizes are set + // If b_alg != 0, this means that trsm blocksizes are set // and we continue with trsm-specific blocksizes. // Else, we query L3 blocksizes and use them for TRSM execution. if( b_alg > 0 ) return bli_determine_blocksize_f_sub( i, dim, b_alg, b_max); @@ -313,10 +313,10 @@ dim_t bli_determine_blocksize_b b_alg = bli_blksz_get_def( dt, bsize ); b_max = bli_blksz_get_max( dt, bsize ); - // If b_use != 0, this means that trsm blocksizes are set + // If b_alg != 0, this means that trsm blocksizes are set // and we continue with trsm-specific blocksizes. // Else, we query L3 blocksizes and use them for TRSM execution. - if( b_alg > 0 ) bli_determine_blocksize_b_sub( i, dim, b_alg, b_max ); + if( b_alg > 0 ) return bli_determine_blocksize_b_sub( i, dim, b_alg, b_max ); } diff --git a/frame/base/bli_cntx.h b/frame/base/bli_cntx.h index d868167234..8dab2a5a19 100644 --- a/frame/base/bli_cntx.h +++ b/frame/base/bli_cntx.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. 
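Besides the comment fix, the bli_blksz.c hunk above adds a missing return in bli_determine_blocksize_b(): without it the TRSM-specific block size was computed and then discarded, so the backward partitioning silently fell through to the generic L3 size. The toy function below shows the intended fallback shape; the names, the simplified clamping, and the sizes used in main() are illustrative only.

/* Illustrative shape of the fixed fallback: use a TRSM-specific block size
   when one is registered (non-zero), otherwise fall back to the generic L3
   block size. Dropping the "return" on the first branch is the bug fixed
   in the patch. */
#include <stdio.h>

static long clamp_to_dim( long dim_left, long b_alg, long b_max )
{
    /* The last partition may grow up to b_max to avoid a tiny trailing block. */
    if ( dim_left <= b_max ) return dim_left;
    return b_alg;
}

static long determine_blocksize_b( long dim_left,
                                   long trsm_b_alg, long trsm_b_max,
                                   long l3_b_alg,   long l3_b_max )
{
    if ( trsm_b_alg > 0 )
        return clamp_to_dim( dim_left, trsm_b_alg, trsm_b_max );  /* the added return */

    return clamp_to_dim( dim_left, l3_b_alg, l3_b_max );
}

int main( void )
{
    printf( "%ld\n", determine_blocksize_b( 1000, 120, 144, 256, 288 ) );  /* 120 */
    printf( "%ld\n", determine_blocksize_b( 1000,   0,   0, 256, 288 ) );  /* 256 */
    return 0;
}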
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -601,6 +601,27 @@ BLIS_INLINE bool bli_cntx_l3_vir_ukr_dislikes_storage_of( obj_t* obj, l3ukr_t uk !bli_cntx_l3_vir_ukr_prefers_storage_of( obj, ukr_id, cntx ); } +BLIS_INLINE bool bli_cntx_l3_vir_ukr_prefers_storage_of_md( obj_t* obj, num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) +{ + // we use the computation datatype, which may differ from the + // storage datatype of C + const bool ukr_prefers_rows + = bli_cntx_l3_vir_ukr_prefers_rows_dt( dt, ukr_id, cntx ); + const bool ukr_prefers_cols + = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, ukr_id, cntx ); + bool r_val = FALSE; + + if ( bli_obj_is_row_stored( obj ) && ukr_prefers_rows ) r_val = TRUE; + else if ( bli_obj_is_col_stored( obj ) && ukr_prefers_cols ) r_val = TRUE; + return r_val; +} + +BLIS_INLINE bool bli_cntx_l3_vir_ukr_dislikes_storage_of_md( obj_t* obj, num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) +{ + return ( bool ) + !bli_cntx_l3_vir_ukr_prefers_storage_of_md( obj, dt, ukr_id, cntx ); +} + // ----------------------------------------------------------------------------- BLIS_INLINE bool bli_cntx_l3_sup_thresh_is_met( obj_t* a, obj_t* b, obj_t* c, cntx_t* cntx ) { @@ -621,6 +642,30 @@ BLIS_INLINE bool bli_cntx_l3_sup_thresh_is_met( obj_t* a, obj_t* b, obj_t* c, cn } + + if(dt == BLIS_DOUBLE) + { + /** + * In case of both matrices having large strides, + * are to be handled in native path, since native + * path does packing of both matrices by default. + * It helps avoiding huge memory jumps while accessing + * matrices during GEMM computation. + */ + dim_t k = bli_obj_width( a ); + inc_t rs_a = bli_obj_row_stride( a ); + inc_t cs_a = bli_obj_col_stride( a ); + inc_t rs_b = bli_obj_row_stride( b ); + inc_t cs_b = bli_obj_col_stride( b ); + inc_t stride_a = rs_a > cs_a ? rs_a : cs_a; + inc_t stride_b = rs_b > cs_b ? rs_b : cs_b; + if( (m > 5000 && n > 700 && k > 120) && (stride_a > 5000 && stride_b > 5000) ) + { + return FALSE; + } + } + + if ( m < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_MT, cntx ) ) return TRUE; if ( n < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx ) ) return TRUE; if ( k < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx ) ) return TRUE; diff --git a/frame/base/bli_cpuid.c b/frame/base/bli_cpuid.c index d10ea1039a..2796fadc05 100644 --- a/frame/base/bli_cpuid.c +++ b/frame/base/bli_cpuid.c @@ -114,6 +114,10 @@ arch_t bli_cpuid_query_id( void ) // Check for each AMD configuration that is enabled, check for that // microarchitecture. We check from most recent to most dated. +#ifdef BLIS_CONFIG_ZEN4 + if ( bli_cpuid_is_zen4( family, model, features ) ) + return BLIS_ARCH_ZEN4; +#endif #ifdef BLIS_CONFIG_ZEN3 if ( bli_cpuid_is_zen3( family, model, features ) ) return BLIS_ARCH_ZEN3; @@ -264,6 +268,41 @@ bool bli_cpuid_is_penryn } // ----------------------------------------------------------------------------- +bool bli_cpuid_is_zen4 + ( + uint32_t family, + uint32_t model, + uint32_t features + ) +{ + // Check for expected CPU features. 
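One more note on the bli_cntx_l3_sup_thresh_is_met() change above before the CPUID hunks continue: for double precision, large problems whose operands both have very large leading dimensions are bounced back to the native path (return FALSE), since that path packs both matrices and so avoids the long memory jumps. The restatement below keeps only the thresholds from the hunk; the function name and the example in main() are hypothetical.

/* Self-contained restatement of the double-precision early exit added to
   the SUP threshold check above. */
#include <stdbool.h>
#include <stdio.h>

static bool dgemm_prefers_native_for_strides( long m, long n, long k,
                                              long rs_a, long cs_a,
                                              long rs_b, long cs_b )
{
    long stride_a = rs_a > cs_a ? rs_a : cs_a;
    long stride_b = rs_b > cs_b ? rs_b : cs_b;

    return ( m > 5000 && n > 700 && k > 120 ) &&
           ( stride_a > 5000 && stride_b > 5000 );
}

int main( void )
{
    /* 6000x1000 C with k=500, both operands embedded in an 8192-ld workspace. */
    printf( "%s\n", dgemm_prefers_native_for_strides( 6000, 1000, 500,
                                                      1, 8192, 1, 8192 )
                        ? "native path" : "sup path eligible" );
    return 0;
}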
+ const uint32_t expected = FEATURE_SSE3 | + FEATURE_SSSE3 | + FEATURE_SSE41 | + FEATURE_SSE42 | + FEATURE_AVX | + FEATURE_AVX2 | + FEATURE_FMA3 | + FEATURE_AVX512F | + FEATURE_AVX512DQ | + FEATURE_AVX512CD | + FEATURE_AVX512BW | + FEATURE_AVX512VL | + FEATURE_AVX512VNNI; + + if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; + + // For zen4 the family id is 0x19 + if ( family != 0x19 ) return FALSE; + + // All family 0x19 CPUs that support AVX512 instructions are zen4, + // thus no need to check model numbers here. Family 0x19 CPUs that + // don't support AVX512 are zen3. Their model ranges are tested in + // a separate function below. + + return TRUE; +} + bool bli_cpuid_is_zen3 ( uint32_t family, @@ -516,6 +555,114 @@ bool bli_cpuid_is_avx_supported( void ) return is_avx_supported; } + +// Check (at runtime) if AVX512_VNNI is supported on the current platform, this +// is to ensure that AVX512_VNNI kernels are not used on legacy platforms which +// results in crash. + +// The support for AVX512_VNNI is checked only once (when this API is called +// first time). On subsequent calls the cached value is returned. +static bool is_avx512vnni_supported = FALSE; + +// Determine if the CPU has support for AVX512_VNNI. +void bli_cpuid_check_avx512vnni_support( void ) +{ + uint32_t family, model, features; + + // Call the CPUID instruction and parse its results into a family id, + // model id, and a feature bit field. + bli_cpuid_query( &family, &model, &features ); + + // Check for expected CPU features. + const uint32_t expected = FEATURE_AVX | + FEATURE_FMA3 | + FEATURE_AVX2 | + FEATURE_AVX512F | + FEATURE_AVX512DQ | + FEATURE_AVX512BW | + FEATURE_AVX512VL | + FEATURE_AVX512VNNI; + + if ( !bli_cpuid_has_features( features, expected ) ) + { + is_avx512vnni_supported = FALSE; + } + else + { + is_avx512vnni_supported = TRUE; + } +} + +// The support for AVX512_BF16 is checked only once (when this API is called +// first time). On subsequent calls the cached value is returned. +static bool is_avx512bf16_supported = FALSE; + +// Determine if the CPU has support for AVX512_BF16. +void bli_cpuid_check_avx512_bf16_support( void ) +{ + uint32_t family, model, features; + + // Call the CPUID instruction and parse its results into a family id, + // model id, and a feature bit field. + bli_cpuid_query( &family, &model, &features ); + + // Check for expected CPU features. 
+ const uint32_t expected = FEATURE_AVX | + FEATURE_FMA3 | + FEATURE_AVX2 | + FEATURE_AVX512F | + FEATURE_AVX512DQ | + FEATURE_AVX512BW | + FEATURE_AVX512VL | + FEATURE_AVX512VNNI | + FEATURE_AVX512BF16 + ; + + if ( !bli_cpuid_has_features( features, expected ) ) + { + is_avx512bf16_supported = FALSE; + } + else + { + is_avx512bf16_supported = TRUE; + } +} + +static bli_pthread_once_t once_check_avx512vnni_support = BLIS_PTHREAD_ONCE_INIT; +static bli_pthread_once_t once_check_avx512_bf16_support = BLIS_PTHREAD_ONCE_INIT; + +// Ensure that actual support determination happens only once +void bli_cpuid_check_avx512vnni_support_once( void ) +{ +#ifndef BLIS_CONFIGURETIME_CPUID + bli_pthread_once( &once_check_avx512vnni_support, bli_cpuid_check_avx512vnni_support ); +#endif +} + +// Ensure that actual support determination happens only once to avoid performance hit +void bli_cpuid_check_avx512_bf16_support_once( void ) +{ +#ifndef BLIS_CONFIGURETIME_CPUID + bli_pthread_once( &once_check_avx512_bf16_support, bli_cpuid_check_avx512_bf16_support ); +#endif +} + +// API to check if AVX512_VNNI is supported or not on the current platform. +bool bli_cpuid_is_avx512vnni_supported( void ) +{ + bli_cpuid_check_avx512vnni_support_once(); + + return is_avx512vnni_supported; +} + +// API to check if AVX512_bf16 is supported or not on the current platform. +bool bli_cpuid_is_avx512_bf16_supported( void ) +{ + bli_cpuid_check_avx512_bf16_support_once(); + + return is_avx512bf16_supported; +} + #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) arch_t bli_cpuid_query_id( void ) @@ -716,6 +863,8 @@ enum FEATURE_MASK_AVX512CD = (1u<<28), // cpuid[eax=7,ecx=0] :ebx[28] FEATURE_MASK_AVX512BW = (1u<<30), // cpuid[eax=7,ecx=0] :ebx[30] FEATURE_MASK_AVX512VL = (1u<<31), // cpuid[eax=7,ecx=0] :ebx[31] + FEATURE_MASK_AVX512VNNI = (1u<<11), // cpuid[eax=7,ecx=0] :ecx[11] + FEATURE_MASK_AVX512BF16 = (1u<< 5), // cpuid[eax=7,ecx=1] :eax[5] FEATURE_MASK_XGETBV = (1u<<26)| (1u<<27), // cpuid[eax=1] :ecx[27:26] XGETBV_MASK_XMM = 0x02u, // xcr0[1] @@ -782,6 +931,18 @@ uint32_t bli_cpuid_query if ( bli_cpuid_has_features( ebx, FEATURE_MASK_AVX512CD ) ) *features |= FEATURE_AVX512CD; if ( bli_cpuid_has_features( ebx, FEATURE_MASK_AVX512BW ) ) *features |= FEATURE_AVX512BW; if ( bli_cpuid_has_features( ebx, FEATURE_MASK_AVX512VL ) ) *features |= FEATURE_AVX512VL; + + if ( bli_cpuid_has_features( ecx, FEATURE_MASK_AVX512VNNI ) ) *features |= FEATURE_AVX512VNNI; + + // This is actually a macro that modifies the last four operands, + // hence why they are not passed by address. + // This returns extended feature flags in EAX. + // The availability of AVX512_BF16 can be found using the + // 5th feature bit of the returned value + __cpuid_count( 7, 1, eax, ebx, ecx, edx ); + + if ( bli_cpuid_has_features( eax, FEATURE_MASK_AVX512BF16 ) ) *features |= FEATURE_AVX512BF16; + } // Check extended processor info / features bits for AMD-specific features. diff --git a/frame/base/bli_cpuid.h b/frame/base/bli_cpuid.h index 47b584c883..805f31bf2e 100644 --- a/frame/base/bli_cpuid.h +++ b/frame/base/bli_cpuid.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2020, Advanced Micro Devices, Inc. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. 
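The cpuid additions above read AVX512_VNNI from leaf 7 / subleaf 0, ECX bit 11, and AVX512_BF16 from leaf 7 / subleaf 1, EAX bit 5, and cache the result behind a once-initializer so the query runs only on first use. The sketch below shows the same two checks with the GCC/Clang cpuid.h helper and plain pthread_once; it is not the BLIS implementation (which goes through bli_cpuid_query() and bli_pthread_once()) and assumes an x86 target.

/* Minimal standalone check for the two CPUID bits added in the patch,
   cached via pthread_once the same way the bli_cpuid_check_*_once()
   wrappers are. x86 with GCC/Clang only. */
#include <cpuid.h>
#include <pthread.h>
#include <stdbool.h>
#include <stdio.h>

static bool           has_vnni, has_bf16;
static pthread_once_t feature_once = PTHREAD_ONCE_INIT;

static void query_features( void )
{
    unsigned eax, ebx, ecx, edx;

    if ( __get_cpuid_count( 7, 0, &eax, &ebx, &ecx, &edx ) )
        has_vnni = ( ecx >> 11 ) & 1u;          /* AVX512_VNNI */

    if ( __get_cpuid_count( 7, 1, &eax, &ebx, &ecx, &edx ) )
        has_bf16 = ( eax >> 5 ) & 1u;           /* AVX512_BF16 */
}

static bool cpu_has_avx512vnni( void )
{
    pthread_once( &feature_once, query_features );
    return has_vnni;
}

static bool cpu_has_avx512bf16( void )
{
    pthread_once( &feature_once, query_features );
    return has_bf16;
}

int main( void )
{
    printf( "avx512_vnni: %d\navx512_bf16: %d\n",
            cpu_has_avx512vnni(), cpu_has_avx512bf16() );
    return 0;
}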
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -61,6 +61,7 @@ bool bli_cpuid_is_sandybridge( uint32_t family, uint32_t model, uint32_t feature bool bli_cpuid_is_penryn( uint32_t family, uint32_t model, uint32_t features ); // AMD +bool bli_cpuid_is_zen4( uint32_t family, uint32_t model, uint32_t features ); bool bli_cpuid_is_zen3( uint32_t family, uint32_t model, uint32_t features ); bool bli_cpuid_is_zen2( uint32_t family, uint32_t model, uint32_t features ); bool bli_cpuid_is_zen( uint32_t family, uint32_t model, uint32_t features ); @@ -133,6 +134,8 @@ BLIS_INLINE bool bli_cpuid_has_features( uint32_t have, uint32_t want ) void get_cpu_name( char *cpu_name ); int vpu_count( void ); bool bli_cpuid_is_avx_supported(void); +bool bli_cpuid_is_avx512vnni_supported(void); +bool bli_cpuid_is_avx512_bf16_supported(void); enum { @@ -142,25 +145,25 @@ enum }; enum { - FEATURE_SSE3 = 0x0001, - FEATURE_SSSE3 = 0x0002, - FEATURE_SSE41 = 0x0004, - FEATURE_SSE42 = 0x0008, - FEATURE_AVX = 0x0010, - FEATURE_AVX2 = 0x0020, - FEATURE_FMA3 = 0x0040, - FEATURE_FMA4 = 0x0080, - FEATURE_AVX512F = 0x0100, - FEATURE_AVX512DQ = 0x0200, - FEATURE_AVX512PF = 0x0400, - FEATURE_AVX512ER = 0x0800, - FEATURE_AVX512CD = 0x1000, - FEATURE_AVX512BW = 0x2000, - FEATURE_AVX512VL = 0x4000 + FEATURE_SSE3 = 0x0001, + FEATURE_SSSE3 = 0x0002, + FEATURE_SSE41 = 0x0004, + FEATURE_SSE42 = 0x0008, + FEATURE_AVX = 0x0010, + FEATURE_AVX2 = 0x0020, + FEATURE_FMA3 = 0x0040, + FEATURE_FMA4 = 0x0080, + FEATURE_AVX512F = 0x0100, + FEATURE_AVX512DQ = 0x0200, + FEATURE_AVX512PF = 0x0400, + FEATURE_AVX512ER = 0x0800, + FEATURE_AVX512CD = 0x1000, + FEATURE_AVX512BW = 0x2000, + FEATURE_AVX512VL = 0x4000, + FEATURE_AVX512VNNI = 0x8000, + FEATURE_AVX512BF16 = 0x10000 }; - - #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) char* find_string_in( char* target, char* buffer, size_t buf_len, char* filepath ); diff --git a/frame/base/bli_env.c b/frame/base/bli_env.c index 23b8e059e1..7fabc2b955 100644 --- a/frame/base/bli_env.c +++ b/frame/base/bli_env.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -69,6 +69,157 @@ gint_t bli_env_get_var( const char* env, gint_t fallback ) return r_val; } +gint_t bli_env_get_var_arch_type( const char* env, gint_t fallback ) +{ + gint_t r_val; + char* str; + int i, size; + + // Query the environment variable and store the result in str. + str = getenv( env ); + + // Set the return value based on the string obtained from getenv(). + if ( str != NULL ) + { + // If there was no error, convert the string to an integer and + // prepare to return that integer. + r_val = ( gint_t )strtol( str, NULL, 10 ); + + if (r_val == 0) + { + // Could be deliberately 0 (now meaning an ERROR) + // or a non-numeric value. We still allow direct + // specification of integer value to select code + // path. Non-zero integer values bypass this code + // block and are handled as before. Here we look + // for known meaningful names, and return 0 if + // we cannot find a match. 
This code MUST be kept + // in synch with arch_t enumeration in + // bli_type_defs.h and array config_name in bli_arch.c + + // convert string to lowercase + size = strlen(str); + for (i=0;i<=size;i++) + { + str[i] = tolower(str[i]); + } + + // Intel + if (strcmp(str, "skx") == 0) + { + r_val = BLIS_ARCH_SKX; + } + else if (strcmp(str, "knl") == 0) + { + r_val = BLIS_ARCH_KNL; + } + else if (strcmp(str, "knc") == 0) + { + r_val = BLIS_ARCH_KNC; + } + else if (strcmp(str, "haswell") == 0) + { + r_val = BLIS_ARCH_HASWELL; + } + else if (strcmp(str, "sandybridge") == 0) + { + r_val = BLIS_ARCH_SANDYBRIDGE; + } + else if (strcmp(str, "penryn") == 0) + { + r_val = BLIS_ARCH_PENRYN; + } + // AMD + else if (strcmp(str, "zen4") == 0) + { + r_val = BLIS_ARCH_ZEN4; + } + else if (strcmp(str, "zen3") == 0) + { + r_val = BLIS_ARCH_ZEN3; + } + else if (strcmp(str, "zen2") == 0) + { + r_val = BLIS_ARCH_ZEN2; + } + else if ((strcmp(str, "zen") == 0) || + (strcmp(str, "zen1") == 0)) + { + r_val = BLIS_ARCH_ZEN; + } + else if (strcmp(str, "excavator") == 0) + { + r_val = BLIS_ARCH_EXCAVATOR; + } + else if (strcmp(str, "steamroller") == 0) + { + r_val = BLIS_ARCH_STEAMROLLER; + } + else if (strcmp(str, "piledriver") == 0) + { + r_val = BLIS_ARCH_PILEDRIVER; + } + else if (strcmp(str, "bulldozer") == 0) + { + r_val = BLIS_ARCH_BULLDOZER; + } + // ARM + else if (strcmp(str, "thunderx2") == 0) + { + r_val = BLIS_ARCH_THUNDERX2; + } + else if (strcmp(str, "cortexa57") == 0) + { + r_val = BLIS_ARCH_CORTEXA57; + } + else if (strcmp(str, "cortexa53") == 0) + { + r_val = BLIS_ARCH_CORTEXA53; + } + else if (strcmp(str, "cortexa15") == 0) + { + r_val = BLIS_ARCH_CORTEXA15; + } + else if (strcmp(str, "cortexa9") == 0) + { + r_val = BLIS_ARCH_CORTEXA9; + } + // IBM POWER + else if (strcmp(str, "power10") == 0) + { + r_val = BLIS_ARCH_POWER10; + } + else if (strcmp(str, "power9") == 0) + { + r_val = BLIS_ARCH_POWER9; + } + else if (strcmp(str, "power7") == 0) + { + r_val = BLIS_ARCH_POWER7; + } + else if (strcmp(str, "bgq") == 0) + { + r_val = BLIS_ARCH_BGQ; + } + // Generic + else if (strcmp(str, "generic") == 0) + { + r_val = BLIS_ARCH_GENERIC; + } + + // No else case means we return r_val=0, i.e. this behaves + // the same as generic bli_env_get_var(). + } + } + else + { + // If there was an error, use the "fallback" as the return value. + r_val = fallback; + } + + return r_val; +} + #if 0 #ifdef _MSC_VER #define strerror_r(errno,buf,len) strerror_s(buf,len,errno) diff --git a/frame/base/bli_env.h b/frame/base/bli_env.h index de86fadff0..eaa778cd20 100644 --- a/frame/base/bli_env.h +++ b/frame/base/bli_env.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,5 +40,7 @@ gint_t bli_env_get_var( const char* env, gint_t fallback ); //void bli_env_set_var( const char* env, dim_t value ); +gint_t bli_env_get_var_arch_type( const char* env, gint_t fallback ); + #endif diff --git a/frame/base/bli_gks.c b/frame/base/bli_gks.c index 746b141a93..acb36d306f 100644 --- a/frame/base/bli_gks.c +++ b/frame/base/bli_gks.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2020, Advanced Micro Devices, Inc. 
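The strtol-then-name-lookup pattern used by bli_env_get_var_arch_type() above can be summarized with a small standalone sketch. The numeric ids and the short name list below are stand-ins for illustration only, not the real BLIS_ARCH_* values, and the real function recognizes many more names.

#include <ctype.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

static long parse_arch_type( const char* env_name, long fallback )
{
    const char* str = getenv( env_name );
    if ( str == NULL ) return fallback;

    long val = strtol( str, NULL, 10 );
    if ( val != 0 ) return val;            // explicit integer id given

    // strtol() returns 0 for non-numeric strings, so fall back to names.
    char buf[ 32 ];
    size_t n = strlen( str );
    if ( n >= sizeof( buf ) ) return 0;    // unknown -> "error" id
    for ( size_t i = 0; i <= n; i++ )
        buf[ i ] = ( char )tolower( ( unsigned char )str[ i ] );

    if ( strcmp( buf, "zen4"    ) == 0 ) return 4;  // stand-in ids
    if ( strcmp( buf, "zen3"    ) == 0 ) return 3;
    if ( strcmp( buf, "generic" ) == 0 ) return 1;
    return 0;                              // unrecognized name -> error id
}

int main( void )
{
    printf( "arch id: %ld\n", parse_arch_type( "BLIS_ARCH_TYPE", -1 ) );
    return 0;
}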
+ Copyright (C) 2018-2021, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -97,6 +97,11 @@ void bli_gks_init( void ) #endif // AMD architectures +#ifdef BLIS_CONFIG_ZEN4 + bli_gks_register_cntx( BLIS_ARCH_ZEN4, bli_cntx_init_zen4, + bli_cntx_init_zen4_ref, + bli_cntx_init_zen4_ind ); +#endif #ifdef BLIS_CONFIG_ZEN3 bli_gks_register_cntx( BLIS_ARCH_ZEN3, bli_cntx_init_zen3, bli_cntx_init_zen3_ref, @@ -165,7 +170,7 @@ void bli_gks_init( void ) bli_gks_register_cntx( BLIS_ARCH_POWER10, bli_cntx_init_power10, bli_cntx_init_power10_ref, bli_cntx_init_power10_ind ); -#endif +#endif #ifdef BLIS_CONFIG_POWER9 bli_gks_register_cntx( BLIS_ARCH_POWER9, bli_cntx_init_power9, bli_cntx_init_power9_ref, @@ -247,7 +252,7 @@ void bli_gks_finalize( void ) void bli_gks_init_index( void ) { // This function is called by bli_gks_init(). It simply initializes all - // architecture id elements of the internal arrays to NULL. + // architecture id elements of the internal arrays to NULL. const size_t gks_size = sizeof( cntx_t* ) * BLIS_NUM_ARCHS; const size_t fpa_size = sizeof( void_fp ) * BLIS_NUM_ARCHS; @@ -360,7 +365,7 @@ void bli_gks_register_cntx // functions for reference kernels and induced method execution. The // former will be used whenever we need to obtain reference kernels and // latter will be used later on if the user calls a level-3 function - // with induced execution enabled. + // with induced execution enabled. cntx_ref_init[ id ] = ref_fp; cntx_ind_init[ id ] = ind_fp; @@ -554,7 +559,7 @@ cntx_t* bli_gks_query_ind_cntx // function on the newly allocated structure, we must first copy // over the contents of the native context. *gks_id_ind = *gks_id_nat; - + // Use the architecture id to look up the function pointer to the // context initialization function for induced methods. ind_cntx_init_ft f = cntx_ind_init[ id ]; diff --git a/frame/base/bli_info.c b/frame/base/bli_info.c index bfd6f6fcc8..cc350ab606 100644 --- a/frame/base/bli_info.c +++ b/frame/base/bli_info.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/base/bli_init.c b/frame/base/bli_init.c index 1207058f12..b037fbd217 100644 --- a/frame/base/bli_init.c +++ b/frame/base/bli_init.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -74,7 +74,7 @@ void bli_finalize_auto( void ) void bli_init_apis( void ) { - /* Initialzie DTL Libary with trace level set by the user */ + /* Initialize DTL Library with trace level set by the user */ AOCL_DTL_INITIALIZE(AOCL_DTL_TRACE_LEVEL); // Initialize various sub-APIs. bli_gks_init(); diff --git a/frame/base/bli_pool.c b/frame/base/bli_pool.c index 7e561983c6..8e11b451f6 100644 --- a/frame/base/bli_pool.c +++ b/frame/base/bli_pool.c @@ -5,7 +5,7 @@ libraries. 
Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -52,6 +52,11 @@ void bli_pool_init // Make sure that block_ptrs_len is at least num_blocks. block_ptrs_len = bli_max( block_ptrs_len, num_blocks ); + // Handle the case where block_ptrs_len is zero, we explicitly set it to 1, + // to avoid any malloc() with zero size, whose behavior is not fixed, and + // also to prevent from falling into any further memory corruption bug. + block_ptrs_len = ( block_ptrs_len == 0 ) ? 1 : block_ptrs_len; + #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_pool_init(): allocating block_ptrs (length %d): ", ( int )block_ptrs_len ); @@ -373,7 +378,15 @@ void bli_pool_grow { // To prevent this from happening often, we double the current // length of the block_ptrs array. - const siz_t block_ptrs_len_new = 2 * block_ptrs_len_cur; + // Sanity: make sure that the block_ptrs_len_new will be at least + // num_blocks_new, in case doubling the block_ptrs_len_cur is not enough. + // Example 1: + // - block_ptrs_len_cur == num_blocks_cur == 0 and num_blocks_add = 1 + // - So doubling: 2 * block_ptrs_len_cur = 0, whereas 1 is expected + // Example 2: + // - block_ptrs_len_cur == num_blocks_cur == 10 and num_blocks_add = 30 + // - So doubling: 2 * block_ptrs_len_cur = 20, whereas 40 is expected + const siz_t block_ptrs_len_new = bli_max( (2 * block_ptrs_len_cur), num_blocks_new ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_pool_grow(): growing block_ptrs_len (%d -> %d): ", diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index fbf5654b7a..c6d2cf5b4a 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -59,12 +59,20 @@ void bli_rntm_init_from_global( rntm_t* rntm ) // Acquire the mutex protecting global_rntm. bli_pthread_mutex_lock( &global_rntm_mutex ); - // Update the latest value of number of threads into global rntm structure, - // before copying into local rntm structure. This updated value will be - // used in the subsequent parallel regions. + // If BLIS_NUM_THREADS environment variable is not set or + // if bli_thread_set_num_threads() API is not used by the + // application, blis_mt flag will be false. + // Then we derive number of threads using OpenMP API + // omp_get_max_threads(), and update into global rntm structure, + // before copying into local rntm structure. + + // This updated value will be used in the subsequent parallel regions. 
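To make the guard being introduced here easier to follow, a minimal standalone sketch is shown below. It assumes an OpenMP build (compile with -fopenmp); rt_state_t and its fields are simplified stand-ins for the BLIS rntm_t/global_rntm, not the real definitions.

#include <omp.h>
#include <stdbool.h>
#include <stdio.h>

typedef struct
{
    bool blis_mt;      // true once BLIS_NUM_THREADS or the BLIS API set a count
    int  num_threads;
} rt_state_t;

static rt_state_t global_state = { false, 1 };

static int effective_num_threads( void )
{
    // Only consult OpenMP when no BLIS-specific thread count was given.
    if ( !global_state.blis_mt )
        global_state.num_threads = omp_get_max_threads();

    return global_state.num_threads;
}

int main( void )
{
    printf( "threads: %d\n", effective_num_threads() );

    // Once a BLIS-style setter runs, OpenMP settings no longer override it.
    global_state.blis_mt     = true;
    global_state.num_threads = 4;
    omp_set_num_threads( 8 );
    printf( "threads: %d\n", effective_num_threads() );  // still 4
    return 0;
}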
+ if(!(global_rntm.blis_mt)) + { #ifdef BLIS_ENABLE_OPENMP - global_rntm.num_threads = n_threads; + global_rntm.num_threads = n_threads; #endif + } *rntm = global_rntm; @@ -624,14 +632,29 @@ void bli_nthreads_optimum( dim_t n = bli_obj_width(c); dim_t k = bli_obj_width_after_trans(a); - if((m<=128 || n<=128 || k<=128) && ((m+n+k) <= 400) ) + if((m<=128 || n<=128 || k<=128) && ((m+n+k) <= 400)) { n_threads_ideal = 8; } - else if((m<=256 || n<=256 || k<=256) && ((m+n+k) <= 800) ) + else if((m<=256 || n<=256 || k<=256) && ((m+n+k) <= 800)) { n_threads_ideal = 16; } + if((m<=48) || (n<=48) || (k<=48)) + { + if((m+n+k) <= 840) + { + n_threads_ideal = 8; + } + else if((m+n+k) <= 1240) + { + n_threads_ideal = 16; + } + else if((m+n+k) <= 1540) + { + n_threads_ideal = 32; + } + } } else if( family == BLIS_SYRK && bli_obj_is_double(c)) { @@ -679,17 +702,87 @@ void bli_nthreads_optimum( { dim_t n = bli_obj_length(c); dim_t k = bli_obj_width_after_trans(a); - dim_t product = (n*k)>>4; /* product is derived based on n and k */ - //Limit the number thread for smaller sizes: - if(product <= 346) + if ( n < 8 ) { - n_threads_ideal = 1; + if ( k <= 512) + { + n_threads_ideal = 1; + } + else if ( k <= 1024 ) + { + n_threads_ideal = 4; + } } - /* finer threshold needs to set for max_thread cap of 2,3,4,5,6..32 */ - else + else if ( n < 32 ) { - n_threads_ideal = n_threads; + if ( k < 128 ) + { + n_threads_ideal = 1; + } + else if ( k <= 512 ) + { + n_threads_ideal = 4; + } + else if ( k <= 1024 ) + { + n_threads_ideal = 6; + } + else if ( k <= 1600 ) + { + n_threads_ideal = 10; + } + } + else if ( n <= 40 ) + { + if ( k < 32 ) + { + n_threads_ideal = 2; + } + else if ( k < 128 ) + { + n_threads_ideal = 4; + } + else if ( k <= 256 ) + { + n_threads_ideal = 8; + } + } + else if ( n < 115 ) + { + if ( k < 128 ) + { + n_threads_ideal = 6; + } + else if ( k <= 216 ) + { + n_threads_ideal = 8; + } + } + else if ( n <= 160 ) + { + if ( k <= 132 ) + { + n_threads_ideal = 8; + } + } + else if ( n < 176 ) + { + if ( k < 128 ) + { + n_threads_ideal = 8; + } + else if ( k <= 512 ) + { + n_threads_ideal = 14; + } + } + else if ( n <= 220 ) + { + if ( k < 128 ) + { + n_threads_ideal = 8; + } } } else if( family == BLIS_TRMM && bli_obj_is_double(c)) diff --git a/frame/base/bli_rntm.h b/frame/base/bli_rntm.h index e28463c5ab..c45184c57d 100644 --- a/frame/base/bli_rntm.h +++ b/frame/base/bli_rntm.h @@ -66,6 +66,11 @@ BLIS_INLINE bool bli_rntm_auto_factor( rntm_t* rntm ) return rntm->auto_factor; } +BLIS_INLINE bool bli_rntm_blis_mt( rntm_t* rntm ) +{ + return rntm->blis_mt; +} + BLIS_INLINE dim_t bli_rntm_num_threads( rntm_t* rntm ) { return rntm->num_threads; @@ -154,6 +159,11 @@ BLIS_INLINE void bli_rntm_set_auto_factor_only( bool auto_factor, rntm_t* rntm ) rntm->auto_factor = auto_factor; } +BLIS_INLINE void bli_rntm_set_blis_mt_only( bool blis_mt, rntm_t* rntm ) +{ + rntm->blis_mt = blis_mt; +} + BLIS_INLINE void bli_rntm_set_num_threads_only( dim_t nt, rntm_t* rntm ) { rntm->num_threads = nt; diff --git a/frame/compat/CMakeLists.txt b/frame/compat/CMakeLists.txt index 48b66acbcb..bfe8e10508 100644 --- a/frame/compat/CMakeLists.txt +++ b/frame/compat/CMakeLists.txt @@ -35,7 +35,8 @@ ${CMAKE_CURRENT_SOURCE_DIR}/bla_omatadd.c # Select AMD specific sources for AMD configurations. 
if(${TARGET_ARCH} STREQUAL zen OR ${TARGET_ARCH} STREQUAL zen2 OR -${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL zen4 OR ${TARGET_ARCH} STREQUAL amdzen) target_sources("${PROJECT_NAME}" PRIVATE diff --git a/frame/compat/bla_amax_amd.c b/frame/compat/bla_amax_amd.c index 7f1a771f7c..2f7c2d2491 100644 --- a/frame/compat/bla_amax_amd.c +++ b/frame/compat/bla_amax_amd.c @@ -162,13 +162,15 @@ f77_int isamax_ // Non-AVX platforms will use the kernels derived from the context. if (bli_cpuid_is_avx_supported() == TRUE) { + cntx_t* cntx = bli_gks_query_cntx(); + samaxv_ker_ft f = bli_cntx_get_l1v_ker_dt(BLIS_FLOAT, BLIS_AMAXV_KER, cntx ); /* Call BLIS kernel */ - bli_samaxv_zen_int + f ( - n0, - x0, incx0, - &bli_index, - NULL + n0, + x0, incx0, + &bli_index, + NULL ); } else @@ -258,13 +260,15 @@ f77_int idamax_ // Non-AVX platforms will use the kernels derived from the context. if (bli_cpuid_is_avx_supported() == TRUE) { + cntx_t* cntx = bli_gks_query_cntx(); + damaxv_ker_ft f = bli_cntx_get_l1v_ker_dt(BLIS_DOUBLE, BLIS_AMAXV_KER, cntx ); /* Call BLIS kernel */ - bli_damaxv_zen_int + f ( - n0, - x0, incx0, - &bli_index, - NULL + n0, + x0, incx0, + &bli_index, + NULL ); } else diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index 406ff69d53..931c80243a 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 22, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -86,6 +86,17 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ @@ -175,6 +186,17 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ @@ -302,7 +324,6 @@ void PASTEF77(ch,blasname) \ #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNC_BLAS( gemm,gemm ) -#if 1 void dzgemm_ ( const f77_char* transa, @@ -344,6 +365,17 @@ void dzgemm_ ldc ); + /* Quick return if possible. */ + if ( *m == 0 || *n == 0 || (( PASTEMAC(z,eq0)( *alpha ) || *k == 0) + && PASTEMAC(z,eq1)( *beta ) )) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + /* Map BLAS chars to their corresponding BLIS enumerated type value. 
*/ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); @@ -394,5 +426,5 @@ void dzgemm_ /* Finalize BLIS. */ bli_finalize_auto(); }// end of dzgemm_ -#endif + #endif diff --git a/frame/compat/bla_gemm3m.c b/frame/compat/bla_gemm3m.c index e51cc314de..665c8643dd 100644 --- a/frame/compat/bla_gemm3m.c +++ b/frame/compat/bla_gemm3m.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -83,6 +83,16 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ @@ -164,6 +174,16 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index 99d7371778..a9478581ef 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 22, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -86,6 +86,17 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ @@ -175,6 +186,17 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ @@ -343,6 +365,16 @@ void dgemm_ ldc ); + /* Quick return if possible. 
*/ + if ( *m == 0 || *n == 0 || ((*alpha == 0.0 || *k == 0) && *beta == 1.0)) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans(*transa, &blis_transa); bli_param_map_netlib_to_blis_trans(*transb, &blis_transb); @@ -556,9 +588,10 @@ void dgemm_ #ifdef BLIS_ENABLE_SMALL_MATRIX - //if( ((m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) && (n0 > 2)) - if( ( ( (m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) ) || - ((n0 <= 10) && (k0 <=10)) ) + if(((m0 == n0) && (m0 < 400) && (k0 < 1000)) || + ( (m0 != n0) && (( ((m0 + n0 -k0) < 1500) && + ((m0 + k0-n0) < 1500) && ((n0 + k0-m0) < 1500) ) || + ((n0 <= 100) && (k0 <=100))))) { err_t status = BLIS_FAILURE; if (bli_is_notrans(blis_transa)) @@ -665,6 +698,17 @@ void zgemm_ ldc ); + /* Quick return if possible. */ + if ( *m == 0 || *n == 0 || (( PASTEMAC(z,eq0)( *alpha ) || *k == 0) + && PASTEMAC(z,eq1)( *beta ) )) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); @@ -713,11 +757,72 @@ void zgemm_ //dim_t nt = bli_thread_get_num_threads(); // get number of threads bool nt = bli_thread_get_is_parallel(); // Check if parallel zgemm is invoked. + /* + Invoking the API for input sizes with k=1. + - For single thread, the API has no constraints before invoking. + - For multiple threads, the constraint is that m and n should individually be less than 128. + */ + if((k0 == 1) && ((nt == 0) || ((nt == 1) && (m0 < 128) && (n0 < 128))) + && bli_is_notrans(blis_transa) + && bli_is_notrans(blis_transb)) + { + bli_zgemm_ref_k1_nn( m0, n0, k0, + (dcomplex*)alpha, + (dcomplex*)a, *lda, + (dcomplex*)b, *ldb, + (dcomplex*)beta, + c, *ldc); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS */ + bli_finalize_auto(); + + return; + } + + /* Call Gemv when m/n=1 */ + if (n0 == 1) + { + if (bli_is_notrans(blis_transa)) + { + bli_zgemv_unf_var2( + BLIS_NO_TRANSPOSE, + bli_extract_conj(blis_transb), + m0, k0, + (dcomplex *)alpha, + (dcomplex *)a, rs_a, cs_a, + (dcomplex *)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, + (dcomplex *)beta, + c, rs_c, + ((void *)0)); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + + return; + } + } + else if (m0 == 1) + { + if (bli_is_trans(blis_transb)) + { + bli_zgemv_unf_var2( + blis_transb, + bli_extract_conj(blis_transa), + k0, n0, + (dcomplex *)alpha, + (dcomplex *)b, cs_b, rs_b, + (dcomplex *)a, bli_is_notrans(blis_transa) ? 
cs_a : rs_a, + (dcomplex *)beta, + c, cs_c, + ((void *)0)); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + return; + } + } + #ifdef BLIS_ENABLE_SMALL_MATRIX - if( ( (nt == 0) && (m0 <= 512 ) && ( n0 <= 512 ) && ( k0 <= 512 ) ) || - ( (nt == 1) && ((( m0 <= 32)||(n0 <= 32)||(k0 <=32)) && ((m0+n0+k0)<=100)) ) - ) + if (((nt == 0) && ((m0 <= 512) && (n0 <= 512) && (k0 <= 512))) || + ((nt == 1) && (((m0 <= 32) || (n0 <= 32) || (k0 <= 32)) && ((m0 + n0 + k0) <= 100)))) { err_t status = BLIS_NOT_YET_IMPLEMENTED; if (bli_is_notrans(blis_transa)) @@ -753,6 +858,19 @@ void zgemm_ } } #endif + + // disabling sup path for single thread in zgemm until further tuning. + if (nt == 1) + { + err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + } + } + // fall back on native path when zgemm is not handled in sup path. bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); @@ -768,10 +886,6 @@ void zgemm_ INSERT_GENTFUNC_BLAS_SC( gemm, gemm ) - -// Observed a regression in dgemm with this function addition. -// Disabling temporarily. -#if 1 void dzgemm_ ( const f77_char* transa, @@ -813,6 +927,17 @@ void dzgemm_ ldc ); + /* Quick return if possible. */ + if ( *m == 0 || *n == 0 || (( PASTEMAC(z,eq0)( *alpha ) || *k == 0) + && PASTEMAC(z,eq1)( *beta ) )) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); @@ -863,5 +988,5 @@ void dzgemm_ /* Finalize BLIS. */ bli_finalize_auto(); }// end of dzgemm_ -#endif + #endif diff --git a/frame/compat/bla_gemmt.c b/frame/compat/bla_gemmt.c index e51b943667..7abad40acf 100644 --- a/frame/compat/bla_gemmt.c +++ b/frame/compat/bla_gemmt.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -84,6 +84,16 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ @@ -170,6 +180,16 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. 
*/ \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ diff --git a/frame/compat/bla_hemm.c b/frame/compat/bla_hemm.c index fcd7858731..0e003012d2 100644 --- a/frame/compat/bla_hemm.c +++ b/frame/compat/bla_hemm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -84,6 +84,16 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 || ( PASTEMAC(ch,eq0)( *alpha ) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ @@ -165,6 +175,16 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 || ( PASTEMAC(ch,eq0)( *alpha ) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ diff --git a/frame/compat/bla_symm.c b/frame/compat/bla_symm.c index 078cbf743c..85aebb435f 100755 --- a/frame/compat/bla_symm.c +++ b/frame/compat/bla_symm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -83,6 +83,16 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 || ( PASTEMAC(ch,eq0)( *alpha ) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ @@ -163,6 +173,16 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 || ( PASTEMAC(ch,eq0)( *alpha ) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ diff --git a/frame/compat/bla_syr2k.c b/frame/compat/bla_syr2k.c index b2280423a7..6a4f31b969 100644 --- a/frame/compat/bla_syr2k.c +++ b/frame/compat/bla_syr2k.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin. - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc.All Rights Reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc.All Rights Reserved. 
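The quick-return tests added throughout these BLAS wrappers follow the netlib convention: there is nothing to do when C has no elements, or when nothing would be accumulated (alpha == 0 or k == 0) and C would only be scaled by beta == 1. A scalar sketch of the gemm-style test, written for the real-double case to match the dgemm_ path above:

#include <stdbool.h>
#include <stdio.h>

static bool gemm_quick_return( int m, int n, int k, double alpha, double beta )
{
    return ( m == 0 || n == 0 || ( ( alpha == 0.0 || k == 0 ) && beta == 1.0 ) );
}

int main( void )
{
    printf( "%d\n", gemm_quick_return( 0, 5, 5, 1.0, 0.0 ) );  // 1: C is empty
    printf( "%d\n", gemm_quick_return( 5, 5, 0, 1.0, 1.0 ) );  // 1: k==0 and beta==1
    printf( "%d\n", gemm_quick_return( 5, 5, 0, 1.0, 0.0 ) );  // 0: C must still be scaled by beta
    return 0;
}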
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -83,6 +83,16 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ @@ -172,6 +182,16 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ diff --git a/frame/compat/bla_syrk.c b/frame/compat/bla_syrk.c index 547fceaa79..376b23aec9 100644 --- a/frame/compat/bla_syrk.c +++ b/frame/compat/bla_syrk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin. - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc.All Rights Reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc.All Rights Reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -81,6 +81,16 @@ void PASTEF77(ch,blasname) \ lda, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ @@ -164,6 +174,16 @@ void PASTEF77(ch,blasname) \ lda, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ diff --git a/frame/compat/bla_trmm.c b/frame/compat/bla_trmm.c index ee87b96c04..c319b3ab51 100644 --- a/frame/compat/bla_trmm.c +++ b/frame/compat/bla_trmm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin. - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc.All Rights Reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc.All Rights Reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -86,6 +86,15 @@ void PASTEF77(ch,blasname) \ lda, \ ldb \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 ) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ @@ -168,6 +177,15 @@ void PASTEF77(ch,blasname) \ lda, \ ldb \ ); \ +\ + /* Quick return if possible. 
*/ \ + if ( *m == 0 || *n == 0 ) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ diff --git a/frame/compat/bla_trsm.c b/frame/compat/bla_trsm.c index fea7ba6f17..e99805d8dd 100644 --- a/frame/compat/bla_trsm.c +++ b/frame/compat/bla_trsm.c @@ -85,6 +85,15 @@ void PASTEF77(ch,blasname) \ lda, \ ldb \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 ) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ @@ -169,6 +178,15 @@ void PASTEF77(ch,blasname) \ lda, \ ldb \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 ) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index f479b5eac0..13330a5d08 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -85,6 +85,15 @@ void PASTEF77(ch,blasname) \ lda, \ ldb \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 ) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ @@ -169,6 +178,15 @@ void PASTEF77(ch,blasname) \ lda, \ ldb \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 ) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ @@ -424,6 +442,15 @@ void strsm_ ldb ); + /* Quick return if possible. */ + if ( *m == 0 || *n == 0 ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_side( *side, &blis_side ); bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); @@ -686,6 +713,15 @@ void dtrsm_ ldb ); + /* Quick return if possible. */ + if ( *m == 0 || *n == 0 ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_side( *side, &blis_side ); bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); @@ -886,40 +922,6 @@ void dtrsm_ return; } } - - // bli_trsm_small_mt is performing better than native multithread - // for certain sizes of m & n. -#ifdef BLIS_ENABLE_OPENMP - rntm_t rntm; - bli_rntm_init_from_global( &rntm ); - - // Query the total number of threads from the rntm_t object. 
- dim_t n_threads = bli_rntm_num_threads( &rntm ); - if ( ( (n_threads > 1) && (m0 <= 1500) && (n0 <= 1500) ) || - ( (n_threads == 32) && (m0 <= 2300) && (n0 <= 2300) ) || - ( (n_threads == 16) && (m0 <= 3800) && (n0 <= 3800) ) || - ( (n_threads == 8) && (m0 <= 2800) && (n0 <= 2800) ) || - ( (n_threads == 4) && (m0 <= 2000) && (n0 <= 2000) ) || - ( (n_threads == 2) && (m0 <= 2000) && (n0 <= 2000) ) ) - { - err_t status; - status = bli_trsm_small_mt( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL); - - if ( status == BLIS_SUCCESS ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - } -#endif// BLIS_ENABLE_OPENMP } // bli_cpuid_is_avx_supported #endif// END of BLIS_ENABLE_SMALL_MATRIX_TRSM @@ -982,6 +984,15 @@ void ztrsm_ ldb ); + /* Quick return if possible. */ + if ( *m == 0 || *n == 0 ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_side( *side, &blis_side ); bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); @@ -1246,7 +1257,7 @@ void ztrsm_ return; } } - } // bli_cpuid_is_avx_supported} + } // bli_cpuid_is_avx_supported #endif bli_trsmnat @@ -1308,6 +1319,15 @@ void ctrsm_ ldb ); + /* Quick return if possible. */ + if ( *m == 0 || *n == 0 ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_side( *side, &blis_side ); bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); diff --git a/frame/include/bli_arch_config.h b/frame/include/bli_arch_config.h index a62128dffe..787e3879b8 100644 --- a/frame/include/bli_arch_config.h +++ b/frame/include/bli_arch_config.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -62,6 +62,9 @@ CNTX_INIT_PROTS( penryn ) #endif // -- AMD64 architectures -- +#ifdef BLIS_CONFIG_ZEN4 +CNTX_INIT_PROTS( zen4 ) +#endif #ifdef BLIS_CONFIG_ZEN3 CNTX_INIT_PROTS( zen3 ) #endif @@ -168,6 +171,9 @@ CNTX_INIT_PROTS( generic ) // -- AMD64 architectures -- +#ifdef BLIS_FAMILY_ZEN4 +#include "bli_family_zen4.h" +#endif #ifdef BLIS_FAMILY_ZEN3 #include "bli_family_zen3.h" #endif @@ -258,7 +264,9 @@ CNTX_INIT_PROTS( generic ) #endif // -- AMD64 architectures -- - +#ifdef BLIS_KERNELS_ZEN4 +#include "bli_kernels_zen4.h" +#endif #ifdef BLIS_KERNELS_ZEN2 #include "bli_kernels_zen2.h" #endif diff --git a/frame/include/bli_config_macro_defs.h b/frame/include/bli_config_macro_defs.h index c9e597c9a6..dd6e8f6062 100644 --- a/frame/include/bli_config_macro_defs.h +++ b/frame/include/bli_config_macro_defs.h @@ -241,8 +241,9 @@ #endif #endif -#define BLIS_EXPORT_BLIS BLIS_EXPORT -#define BLIS_EXPORT_BLAS BLIS_EXPORT +#define BLIS_EXPORT_BLIS BLIS_EXPORT +#define BLIS_EXPORT_BLAS BLIS_EXPORT +#define BLIS_EXPORT_ADDON BLIS_EXPORT // -- STATIC INLINE FUNCTIONS -------------------------------------------------- diff --git a/frame/include/bli_genarray_macro_defs.h b/frame/include/bli_genarray_macro_defs.h index 1e9c772fa6..c6ba2d7fbb 100644 --- a/frame/include/bli_genarray_macro_defs.h +++ b/frame/include/bli_genarray_macro_defs.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -140,6 +140,20 @@ arrayname[BLIS_NUM_FP_TYPES][BLIS_NUM_FP_TYPES] = \ +// -- One-operand macro (with custom prefix) -- + +#define GENARRAY_PREF(arrayname,prefix,op) \ +\ +arrayname[BLIS_NUM_FP_TYPES] = \ +{ \ + PASTECH2(prefix,s,op), \ + PASTECH2(prefix,c,op), \ + PASTECH2(prefix,d,op), \ + PASTECH2(prefix,z,op) \ +} + + + // -- Two-operand macros -- diff --git a/frame/include/bli_obj_macro_defs.h b/frame/include/bli_obj_macro_defs.h index 2b3ac35ae0..1499dc4182 100644 --- a/frame/include/bli_obj_macro_defs.h +++ b/frame/include/bli_obj_macro_defs.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -1190,7 +1190,7 @@ BLIS_INLINE stor3_t bli_obj_stor3_from_strides( obj_t* c, obj_t* a, obj_t* b ) // -- Initialization-related macros -- // Finish the initialization started by the matrix-specific static initializer -// (e.g. BLIS_OBJECT_PREINITIALIZER) +// (e.g. BLIS_OBJECT_INITIALIZER) // NOTE: This is intended only for use in the BLAS compatibility API and typed // BLIS API. @@ -1223,7 +1223,7 @@ BLIS_INLINE void bli_obj_init_finish( num_t dt, dim_t m, dim_t n, void* p, inc_t } // Finish the initialization started by the 1x1-specific static initializer -// (e.g. BLIS_OBJECT_PREINITIALIZER_1X1) +// (e.g. BLIS_OBJECT_INITIALIZER_1X1) // NOTE: This is intended only for use in the BLAS compatibility API and typed // BLIS API. 
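The new GENARRAY_PREF macro generates a function-pointer table indexed by floating-point type, with each entry built by token-pasting a prefix, a type character, and an operation name in s, c, d, z order. Below is a self-contained sketch of the same idiom; the paste macro and kernel names are stand-ins, not BLIS's PASTECH2 or real kernel symbols.

#include <stdio.h>

#define NUM_FP_TYPES 4

#define PASTE3(a,b,c)  a ## b ## c

#define GEN_TABLE(arrayname,prefix,op) \
  void (*arrayname[ NUM_FP_TYPES ])( void ) = \
  { \
    PASTE3(prefix,s,op), \
    PASTE3(prefix,c,op), \
    PASTE3(prefix,d,op), \
    PASTE3(prefix,z,op)  \
  }

static void my_skernel( void ) { puts( "s" ); }
static void my_ckernel( void ) { puts( "c" ); }
static void my_dkernel( void ) { puts( "d" ); }
static void my_zkernel( void ) { puts( "z" ); }

// Expands to: void (*kernel_table[4])(void) = { my_skernel, ..., my_zkernel };
GEN_TABLE( kernel_table, my_, kernel );

int main( void )
{
    kernel_table[ 2 ]();   // prints "d" (s, c, d, z ordering)
    return 0;
}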
diff --git a/frame/include/bli_type_defs.h b/frame/include/bli_type_defs.h index 9d45aec1ab..89f9aada33 100644 --- a/frame/include/bli_type_defs.h +++ b/frame/include/bli_type_defs.h @@ -388,7 +388,7 @@ typedef void* void_fp; #define BLIS_BITVAL_SINGLE_PREC 0x0 #define BLIS_BITVAL_DOUBLE_PREC BLIS_PRECISION_BIT #define BLIS_BITVAL_FLOAT_TYPE 0x0 -#define BLIS_BITVAL_SCOMPLEX_TYPE BLIS_DOMAIN_BIT +#define BLIS_BITVAL_SCOMPLEX_TYPE BLIS_DOMAIN_BIT #define BLIS_BITVAL_DOUBLE_TYPE BLIS_PRECISION_BIT #define BLIS_BITVAL_DCOMPLEX_TYPE ( BLIS_DOMAIN_BIT | BLIS_PRECISION_BIT ) #define BLIS_BITVAL_INT_TYPE 0x04 @@ -398,10 +398,10 @@ typedef void* void_fp; #define BLIS_BITVAL_NO_CONJ 0x0 #define BLIS_BITVAL_CONJ BLIS_CONJ_BIT #define BLIS_BITVAL_CONJ_TRANS ( BLIS_CONJ_BIT | BLIS_TRANS_BIT ) -#define BLIS_BITVAL_ZEROS 0x0 +#define BLIS_BITVAL_ZEROS 0x0 #define BLIS_BITVAL_UPPER ( BLIS_UPPER_BIT | BLIS_DIAG_BIT ) #define BLIS_BITVAL_LOWER ( BLIS_LOWER_BIT | BLIS_DIAG_BIT ) -#define BLIS_BITVAL_DENSE BLIS_UPLO_BITS +#define BLIS_BITVAL_DENSE BLIS_UPLO_BITS #define BLIS_BITVAL_NONUNIT_DIAG 0x0 #define BLIS_BITVAL_UNIT_DIAG BLIS_UNIT_DIAG_BIT #define BLIS_BITVAL_INVERT_DIAG BLIS_INVERT_DIAG_BIT @@ -802,10 +802,11 @@ typedef enum BLIS_GEMMTRSM_L_UKR, BLIS_GEMMTRSM_U_UKR, BLIS_TRSM_L_UKR, - BLIS_TRSM_U_UKR + BLIS_TRSM_U_UKR, + BLIS_GEMM_AVX2_UKR } l3ukr_t; -#define BLIS_NUM_LEVEL3_UKRS 5 +#define BLIS_NUM_LEVEL3_UKRS 6 typedef enum @@ -989,12 +990,21 @@ typedef enum // string array in bli_arch.c. Whenever values are added/inserted // OR if values are rearranged, be sure to update the string array // in bli_arch.c. +// This must also be kept up-to-date with the bli_env_get_var_arch_type() +// function in bli_env.c typedef enum { // NOTE: The C language standard guarantees that the first enum value // starts at 0. + // Initial value, will be selected for an unrecognized (non-integer) + // value of BLIS_ARCH_TYPE + BLIS_ARCH_ERROR, + + // Generic architecture/configuration + BLIS_ARCH_GENERIC, + // Intel BLIS_ARCH_SKX, BLIS_ARCH_KNL, @@ -1004,6 +1014,7 @@ typedef enum BLIS_ARCH_PENRYN, // AMD + BLIS_ARCH_ZEN4, BLIS_ARCH_ZEN3, BLIS_ARCH_ZEN2, BLIS_ARCH_ZEN, @@ -1025,12 +1036,13 @@ typedef enum BLIS_ARCH_POWER7, BLIS_ARCH_BGQ, - // Generic architecture/configuration - BLIS_ARCH_GENERIC + // Dummy value, always the last one. + // In config_name in bli_arch.c this is also set to "generic" + BLIS_ARCH_GENERIC_LAST } arch_t; -#define BLIS_NUM_ARCHS (BLIS_ARCH_GENERIC + 1) +#define BLIS_NUM_ARCHS (BLIS_ARCH_GENERIC_LAST + 1) // @@ -1474,6 +1486,8 @@ typedef struct rntm_s bool pack_a; // enable/disable packing of left-hand matrix A. bool pack_b; // enable/disable packing of right-hand matrix B. bool l3_sup; // enable/disable small matrix handling in level-3 ops. + // blis_mt, flag to figure out whether number of + bool blis_mt;// threads is set using BLIS APIS or OpenMP APIs. // "Internal" fields: these should not be exposed to the end-user. 
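The reshuffled arch_t above relies on a sentinel enumerator so that BLIS_NUM_ARCHS tracks future additions automatically, and on a dedicated error id at value 0 so that an unrecognized BLIS_ARCH_TYPE string can be flagged instead of silently mapping to a real configuration. A stand-in sketch of the same counting idiom (not the real arch_t):

#include <stdio.h>

typedef enum
{
    ARCH_ERROR,         // returned for unrecognized BLIS_ARCH_TYPE strings
    ARCH_GENERIC,
    ARCH_ZEN4,
    ARCH_ZEN3,
    // ... further ids inserted here are counted automatically ...
    ARCH_GENERIC_LAST   // dummy value, always the last enumerator
} arch_id_t;

#define NUM_ARCHS ( ARCH_GENERIC_LAST + 1 )

int main( void )
{
    printf( "NUM_ARCHS = %d\n", NUM_ARCHS );  // 5 in this sketch
    return 0;
}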
@@ -1546,13 +1560,13 @@ typedef enum BLIS_INVALID_COL_STRIDE = ( -51), BLIS_INVALID_DIM_STRIDE_COMBINATION = ( -52), - // Structure-specific errors + // Structure-specific errors BLIS_EXPECTED_GENERAL_OBJECT = ( -60), BLIS_EXPECTED_HERMITIAN_OBJECT = ( -61), BLIS_EXPECTED_SYMMETRIC_OBJECT = ( -62), BLIS_EXPECTED_TRIANGULAR_OBJECT = ( -63), - // Storage-specific errors + // Storage-specific errors BLIS_EXPECTED_UPPER_OR_LOWER_OBJECT = ( -70), // Partitioning-specific errors @@ -1566,7 +1580,7 @@ typedef enum // Packing-specific errors BLIS_PACK_SCHEMA_NOT_SUPPORTED_FOR_UNPACK = (-100), - // Buffer-specific errors + // Buffer-specific errors BLIS_EXPECTED_NONNULL_OBJECT_BUFFER = (-110), // Memory errors diff --git a/frame/include/bli_x86_asm_macros.h b/frame/include/bli_x86_asm_macros.h index a4987b4c5f..ffb2771758 100644 --- a/frame/include/bli_x86_asm_macros.h +++ b/frame/include/bli_x86_asm_macros.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2018, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019-22, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -1271,6 +1271,7 @@ #else #define KMOVW(_0, _1) INSTR_(kmovw, _0, _1) +#define KMOVQ(_0, _1) INSTR_(kmovq, _0, _1) #define JKNZD(_0, _1) INSTR_(kortestw, _0, _0) INSTR_(jnz, _1) #endif @@ -1279,6 +1280,7 @@ #define KSHIFTRW(_0, _1, _2) INSTR_(kshiftrw, _0, _1, _2) #define kmovw(_0, _1) KMOVW(_0, _1) +#define kmovq(_0, _1) KMOVQ(_0, _1) #define jknzd(_0, _1) JKNZD(_0, _1) #define kxnorw(_0, _1, _2) KXNORW(_0, _1, _2) #define kshiftrw(_0, _1, _2) KSHIFTRW(_0, _1, _2) diff --git a/frame/include/blis.h b/frame/include/blis.h index 783b5de0eb..251503de27 100644 --- a/frame/include/blis.h +++ b/frame/include/blis.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -186,6 +186,17 @@ extern "C" { #include "bli_util.h" +// -- addon definitions -- + +// NOTE: These definitions should not be included much earlier since an addon +// may wish to utilize other types and definitions provided by BLIS. +// TODO: Disable addon header file inclusion for windows since configure +// script is not executed, and subsequently the header file ie not generated. +#if !defined(_WIN32) && !defined(__CYGWIN__) +#include "bli_addon.h" +#endif + + // -- sandbox implementation -- #include "bli_sbox.h" diff --git a/frame/thread/bli_thread.c b/frame/thread/bli_thread.c index 097d136e7e..f721bae7e6 100644 --- a/frame/thread/bli_thread.c +++ b/frame/thread/bli_thread.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 22, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -1614,12 +1614,11 @@ void bli_thread_set_num_threads( dim_t n_threads ) bli_rntm_set_num_threads_only( n_threads, &global_rntm ); -#ifdef BLIS_ENABLE_OPENMP - // In the function bli_rntm_init_from_global() we extract n_threads - // using the API omp_get_max_threads(). Following step ensures that - // omp_get_max_threads returns the same value as set here. - omp_set_num_threads( n_threads ); -#endif + // BLIS_NUM_THREADS env variable or BLIS API to set the + // number of threads is used. Setting the blis_mt flag to TRUE + // so that OMP API or OMP env variables will not be of effect + // going forward. + bli_rntm_set_blis_mt_only(TRUE, &global_rntm); // Release the mutex protecting global_rntm. bli_pthread_mutex_unlock( &global_rntm_mutex ); @@ -1642,30 +1641,41 @@ void bli_thread_init_rntm_from_env #ifdef BLIS_ENABLE_MULTITHREADING - // Try to read BLIS_NUM_THREADS first. - nt = bli_env_get_var( "BLIS_NUM_THREADS", -1 ); - - -#ifdef BLIS_ENABLE_OPENMP - // Scenarios: - // 1. If BLIS_NUM_THREADS is set with valid value, set the nt using omp_set_num_threads(nt) - // so that this value can be fetched inside BLIS API as well. - // 2. If BLIS_NUM_THREADS is not set, then if Application is multithreaded and issued + // 1. If BLIS_NUM_THREADS is set with a valid value, same value + // will be used in the subsequent parallel regions unless + // bli_thread_set_num_threads() API is used by the Application + // to modify the desired number of threads during BLIS API execution. + // + // 2. Once BLIS_NUM_THREADS environment variable or bli_thread_set_num_threads(nt) + // API is used by the application, BLIS module would always give precedence to + // these values. BLIS API would not consider the values set using OpenMP API + // omp_set_num_threads(nt) API or OMP_NUM_THREADS environment variable. + // + // 3. If Application wants to allocate separate number of threads for BLIS API execution + // and application, Application can choose either BLIS_NUM_THREADS environement variable + // or bli_thread_set_num_threads(nt) API, to set the desired number of threads + // in BLIS API Execution. Application can use OpenMP APIs or environment variables for + // itself. + // + // 4. If BLIS_NUM_THREADS is not set, then if Application is multithreaded and issued // omp_set_num_threads(nt) with desired number of threads, // omp_get_max_threads() API will fetch the number of threads set earlier. - // 3. If BLIS_NUM_THREADS is not set, omp_set_num_threads(nt) is not called by the application, + // + // 5. If BLIS_NUM_THREADS is not set, omp_set_num_threads(nt) is not called by the application, // but only OMP_NUM_THREADS is set, // omp_get_max_threads() API will fetch the value of OMP_NUM_THREADS. - // 4. If both environment variables are not set, or if they are set with invalid values, and + // + // 6. If both environment variables are not set, or if they are set with invalid values, and // omp_set_num_threads(nt) is not issued by application, // omp_get_max_threads() API will return the number of the cores in the current context. // - // BLIS will rntm->num_threads will also get initialized with the same value. + // BLIS will initialize rntm->num_threads with the same value. // However if omp_set_nested is false - BLIS APIs called from parallel threads will run in sequential. // But if nested parallelism is enabled - Then each application will launch MT BLIS. 
// // Order of precedence used for number of threads: + // 0. valid value set using bli_thread_set_num_threads(nt) by the application // 1. valid value set for BLIS_NUM_THREADS environment variable // 2. omp_set_num_threads(nt) issued by the application // 3. valid value set for OMP_NUM_THREADS environment variable @@ -1676,16 +1686,27 @@ void bli_thread_init_rntm_from_env // // OMP_NUM_THREADS environment variable is applicable only when OpenMP is enabled. + + // Try to read BLIS_NUM_THREADS first. + nt = bli_env_get_var( "BLIS_NUM_THREADS", -1 ); + + // If BLIS_NUM_THREADS is set with a valid value, set the blis_mt flag in global runtime + // structure. Later during API execution, this flag will be checked for TRUE or FALSE. + // If the flag is FALSE, only then the value set by the application using OpenMP API, + // would be fetched and used subsequently. if(nt > 0) { - omp_set_num_threads(nt); + bli_rntm_set_blis_mt_only(TRUE, rntm); } else { + bli_rntm_set_blis_mt_only(FALSE, rntm); + +#ifdef BLIS_ENABLE_OPENMP nt = omp_get_max_threads(); +#endif } -#endif // Read the environment variables for the number of threads (ways // of parallelism) for each individual loop. jc = bli_env_get_var( "BLIS_JC_NT", -1 ); diff --git a/frame/util/bli_util_api_wrap.h b/frame/util/bli_util_api_wrap.h index 78f088e28e..86471c76f6 100644 --- a/frame/util/bli_util_api_wrap.h +++ b/frame/util/bli_util_api_wrap.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/util/bli_util_progress.h b/frame/util/bli_util_progress.h index 0e2a63eb1c..ed7a79cb66 100644 --- a/frame/util/bli_util_progress.h +++ b/frame/util/bli_util_progress.h @@ -37,11 +37,11 @@ // Public interface for the end user. -typedef dim_t (*AOCL_progress_callback)(char *api, - dim_t lapi, - dim_t progress, - dim_t current_thread, - dim_t total_threads); +typedef dim_t (*AOCL_progress_callback)(const char* const api, + const dim_t lapi, + const dim_t progress, + const dim_t current_thread, + const dim_t total_threads); BLIS_EXPORT_BLIS void AOCL_BLIS_set_progress(AOCL_progress_callback func); diff --git a/frame/util/bli_util_unb_var1.c b/frame/util/bli_util_unb_var1.c index a2166b7b1f..78c4c9198d 100644 --- a/frame/util/bli_util_unb_var1.c +++ b/frame/util/bli_util_unb_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
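The const-qualified AOCL_progress_callback typedef above receives the API name together with its length, a progress count, and the thread ids. The sketch below shows a conforming callback, with long standing in for dim_t and the function pointer invoked directly; a real application would link against BLIS and register the callback through AOCL_BLIS_set_progress() instead.

#include <stdio.h>

typedef long (*progress_cb_t)( const char* const api,
                               const long lapi,
                               const long progress,
                               const long current_thread,
                               const long total_threads );

static long my_progress( const char* const api,
                         const long lapi,
                         const long progress,
                         const long current_thread,
                         const long total_threads )
{
    // Print the API name using the supplied length rather than assuming
    // NUL termination.
    printf( "%.*s: %ld elements done (thread %ld of %ld)\n",
            ( int )lapi, api, progress, current_thread, total_threads );
    return 0;
}

int main( void )
{
    progress_cb_t cb = my_progress;
    cb( "dgemm", 5, 1000, 0, 8 );
    return 0;
}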
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -308,7 +308,55 @@ void PASTEMAC(ch,varname) \ //INSERT_GENTFUNCR_BASIC( normfv_unb_var1, sumsqv_unb_var1 ) GENTFUNCR( scomplex, float, c, s, normfv_unb_var1, sumsqv_unb_var1 ) -GENTFUNCR( dcomplex, double, z, d, normfv_unb_var1, sumsqv_unb_var1 ) + +void bli_znormfv_unb_var1 + ( + dim_t n, + dcomplex* x, + inc_t incx, + double* norm, + cntx_t* cntx, + rntm_t* rntm + ) +{ + + if ( bli_cpuid_is_avx_supported() == TRUE ) + { + bli_dznorm2fv_unb_var1_avx2( n, x, incx, norm, cntx ); + } + else + { + double* zero = bli_d0; + double* one = bli_d1; + double scale; + double sumsq; + double sqrt_sumsq; + + // Initialize scale and sumsq to begin the summation. + bli_dcopys( *zero, scale ); + bli_dcopys( *one, sumsq ); + + // Compute the sum of the squares of the vector. + + bli_zsumsqv_unb_var1 + ( + n, + x, + incx, + &scale, + &sumsq, + cntx, + rntm + ); + + // Compute: norm = scale * sqrt( sumsq ) + bli_dsqrt2s( sumsq, sqrt_sumsq ); + bli_dscals( scale, sqrt_sumsq ); + + // Store the final value to the output variable. + bli_dcopys( sqrt_sumsq, *norm ); + } +} #undef GENTFUNCR // We've disabled the dotv-based implementation because that method of @@ -440,34 +488,55 @@ void PASTEMAC(ch,varname) \ } #endif GENTFUNCR( float, float, s, s, normfv_unb_var1, sumsqv_unb_var1 ) -/*call sumsqv_unb_var1 if FAST_MATH is not defined else call dot-norm method*/\ -#ifndef BLIS_ENABLE_FAST_MATH -GENTFUNCR( double, double, d, d, normfv_unb_var1, sumsqv_unb_var1 ) -#else -#undef GENTFUNCR -#define GENTFUNCR( ctype, ctype_r, ch, chr, varname, kername ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - dim_t n, \ - ctype* x, inc_t incx, \ - ctype_r* norm, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ -\ - /* Compute the sum of the squares of the vector. */ \ - PASTEMAC(ch,kername) \ - ( \ - n, \ - x, incx, \ - norm, \ - cntx \ - ); \ + +void bli_dnormfv_unb_var1 + ( + dim_t n, + double* x, + inc_t incx, + double* norm, + cntx_t* cntx, + rntm_t* rntm + ) +{ + + if( bli_cpuid_is_avx_supported() == TRUE ) + { + bli_dnorm2fv_unb_var1_avx2( n, x, incx, norm, cntx ); + } + else + { + double* zero = bli_d0; + double* one = bli_d1; + double scale; + double sumsq; + double sqrt_sumsq; + + // Initialize scale and sumsq to begin the summation. + bli_ddcopys( *zero, scale ); + bli_ddcopys( *one, sumsq ); + + // Compute the sum of the squares of the vector. + + bli_dsumsqv_unb_var1 + ( + n, + x, + incx, + &scale, + &sumsq, + cntx, + rntm + ); + + // Compute: norm = scale * sqrt( sumsq ) + bli_dsqrt2s( sumsq, sqrt_sumsq ); + bli_dscals( scale, sqrt_sumsq ); + + // Store the final value to the output variable. + bli_dcopys( sqrt_sumsq, *norm ); + } } -GENTFUNCR( double, double, d, d, normfv_unb_var1, norm2fv_unb_var1 ) -#endif #undef GENTFUNCR #define GENTFUNCR( ctype, ctype_r, ch, chr, varname ) \ diff --git a/frame/util/bli_util_update.c b/frame/util/bli_util_update.c index b57c065721..6bcd31dff2 100644 --- a/frame/util/bli_util_update.c +++ b/frame/util/bli_util_update.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
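The fallback branch above computes the norm as scale * sqrt(sumsq), where bli_?sumsqv_unb_var1 maintains a running scale to avoid overflow and underflow. A standalone sketch of that scaled sum-of-squares technique is shown below; it is a simplified LAPACK-dlassq-style loop for illustration, not the BLIS kernel itself (link with -lm).

#include <math.h>
#include <stdio.h>

static double norm2_scaled( int n, const double* x, int incx )
{
    double scale = 0.0, sumsq = 1.0;

    for ( int i = 0; i < n; i++ )
    {
        double xi = fabs( x[ i * incx ] );
        if ( xi == 0.0 ) continue;

        if ( scale < xi )
        {
            // Rescale the running sum so the largest |x_i| seen so far maps to 1.
            sumsq = 1.0 + sumsq * ( scale / xi ) * ( scale / xi );
            scale = xi;
        }
        else
        {
            sumsq += ( xi / scale ) * ( xi / scale );
        }
    }

    return scale * sqrt( sumsq );
}

int main( void )
{
    double x[] = { 3.0, 4.0 };
    printf( "%f\n", norm2_scaled( 2, x, 1 ) );  // 5.000000
    return 0;
}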
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -212,12 +212,12 @@ void PASTEMAC(ch, varname) \ c[m*rs_c + n].imag = ct[m*rs_ct + n].imag; \ } \ \ - for(; m < m_cur; m++) \ - for(n = 0; n < n_cur; n++) \ - { \ - c[m*rs_c + n].real = ct[m*rs_ct + n].real; \ - c[m*rs_c + n].imag = ct[m*rs_ct + n].imag; \ - } \ + for(; m < m_cur; m++) \ + for(n = 0; n < n_cur; n++) \ + { \ + c[m*rs_c + n].real = ct[m*rs_ct + n].real; \ + c[m*rs_c + n].imag = ct[m*rs_ct + n].imag; \ + } \ } \ \ return; \ diff --git a/kernels/CMakeLists.txt b/kernels/CMakeLists.txt index 5cf469ef18..bee82f8685 100644 --- a/kernels/CMakeLists.txt +++ b/kernels/CMakeLists.txt @@ -1,4 +1,10 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## add_subdirectory(haswell) add_subdirectory(zen) + +if(${TARGET_ARCH} STREQUAL zen4 OR + ${TARGET_ARCH} STREQUAL amdzen) + add_subdirectory(skx) + add_subdirectory(zen4) +endif() \ No newline at end of file diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c index b99b6eef26..ab42e06aa9 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc.All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc.All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -125,7 +125,7 @@ void bli_cpackm_haswell_asm_3xk mov(var(kappa), rcx) // load address of kappa vbroadcastss(mem(rcx, 0), ymm10) // load kappa_r and duplicate - vbroadcastss(mem(rcx, 8), ymm11) // load kappa_i and duplicate + vbroadcastss(mem(rcx, 4), ymm11) // load kappa_i and duplicate // now branch on kappa == 1.0 diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c index 4cad0c90c3..a101e66d18 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -125,7 +125,7 @@ void bli_cpackm_haswell_asm_8xk mov(var(kappa), rcx) // load address of kappa vbroadcastss(mem(rcx, 0), ymm10) // load kappa_r and duplicate - vbroadcastss(mem(rcx, 8), ymm11) // load kappa_i and duplicate + vbroadcastss(mem(rcx, 4), ymm11) // load kappa_i and duplicate // now branch on kappa == 1.0 diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c index 9deb564ce4..0cfa2e8d68 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. 
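The offset change from 8 to 4 in the two single-precision complex packm kernels above reflects the element layout: for a packed float real/imaginary pair, the imaginary part sits 4 bytes after the real part, so broadcasting kappa_i from offset 8 read past the scalar (an 8-byte offset is correct only for the double-precision complex kernels). A tiny sketch, using a local struct assumed to mirror BLIS's scomplex layout:

#include <stddef.h>
#include <stdio.h>

// Assumed to mirror BLIS's single-precision complex type:
// two packed floats, real part first.
typedef struct { float real; float imag; } scomplex_like;

int main( void )
{
    // The imaginary part lives 4 bytes past the base address, which is why
    // the kernels now broadcast kappa_i from mem(rcx, 4) rather than mem(rcx, 8).
    printf( "offsetof(imag) = %zu\n", offsetof( scomplex_like, imag ) );  // prints 4
    return 0;
}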
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -101,6 +101,8 @@ void bli_dpackm_haswell_asm_8xk // assembly region, this constraint should be lifted. const bool unitk = bli_deq1( *kappa ); + double* restrict a_next = a + cdim0; + // ------------------------------------------------------------------------- @@ -267,7 +269,7 @@ void bli_dpackm_haswell_asm_8xk label(.DCOLUNIT) lea(mem(r10, r10, 2), r13) // r13 = 3*lda - + mov(var(a_next), rcx) mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONKLEFTCOLU) // if i == 0, jump to code that @@ -278,22 +280,27 @@ void bli_dpackm_haswell_asm_8xk vmovupd(mem(rax, 0), ymm0) vmovupd(mem(rax, 32), ymm1) + prefetch(0, mem(rcx,7*8)) vmovupd(ymm0, mem(rbx, 0*64+ 0)) vmovupd(ymm1, mem(rbx, 0*64+32)) vmovupd(mem(rax, r10, 1, 0), ymm2) vmovupd(mem(rax, r10, 1, 32), ymm3) + prefetch(0, mem(rcx, r10, 1,7*8)) vmovupd(ymm2, mem(rbx, 1*64+ 0)) vmovupd(ymm3, mem(rbx, 1*64+32)) vmovupd(mem(rax, r10, 2, 0), ymm4) vmovupd(mem(rax, r10, 2, 32), ymm5) + prefetch(0, mem(rcx, r10, 2,7*8)) vmovupd(ymm4, mem(rbx, 2*64+ 0)) vmovupd(ymm5, mem(rbx, 2*64+32)) vmovupd(mem(rax, r13, 1, 0), ymm6) vmovupd(mem(rax, r13, 1, 32), ymm7) + prefetch(0, mem(rcx, r13, 1,7*8)) add(r14, rax) // a += 4*lda; + add(r14, rcx) vmovupd(ymm6, mem(rbx, 3*64+ 0)) vmovupd(ymm7, mem(rbx, 3*64+32)) add(imm(4*8*8), rbx) // p += 4*ldp = 4*8; @@ -315,7 +322,9 @@ void bli_dpackm_haswell_asm_8xk vmovupd(mem(rax, 0), ymm0) vmovupd(mem(rax, 32), ymm1) + prefetch(0, mem(rcx,7*8)) add(r10, rax) // a += lda; + add(r10, rcx) vmovupd(ymm0, mem(rbx, 0*64+ 0)) vmovupd(ymm1, mem(rbx, 0*64+32)) add(imm(8*8), rbx) // p += ldp = 8; @@ -343,7 +352,8 @@ void bli_dpackm_haswell_asm_8xk [p] "m" (p), [ldp] "m" (ldp), [kappa] "m" (kappa), - [one] "m" (one) + [one] "m" (one), + [a_next] "m" (a_next) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", /*"r9",*/ "r10", /*"r11",*/ "r12", "r13", "r14", "r15", diff --git a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c index b4ac979e1a..79625519c5 100644 --- a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c +++ b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc.All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc.All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -102,7 +102,19 @@ void bli_sgemm_haswell_asm_6x16 begin_asm() - vzeroall() // zero all xmm/ymm registers. + //vzeroall() // zero all xmm/ymm registers. + vxorps( ymm4, ymm4, ymm4) + vmovaps( ymm4, ymm5) + vmovaps( ymm4, ymm6) + vmovaps( ymm4, ymm7) + vmovaps( ymm4, ymm8) + vmovaps( ymm4, ymm9) + vmovaps( ymm4, ymm10) + vmovaps( ymm4, ymm11) + vmovaps( ymm4, ymm12) + vmovaps( ymm4, ymm13) + vmovaps( ymm4, ymm14) + vmovaps( ymm4, ymm15) mov(var(a), rax) // load address of a. 
@@ -141,7 +153,7 @@ void bli_sgemm_haswell_asm_6x16 // iteration 0 prefetch(0, mem(rax, 64*4)) - + vbroadcastss(mem(rax, 0*4), ymm2) vbroadcastss(mem(rax, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) @@ -167,6 +179,8 @@ void bli_sgemm_haswell_asm_6x16 vmovaps(mem(rbx, -1*32), ymm1) // iteration 1 + prefetch(0, mem(rax, 72*4)) + vbroadcastss(mem(rax, 6*4), ymm2) vbroadcastss(mem(rax, 7*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) @@ -192,7 +206,7 @@ void bli_sgemm_haswell_asm_6x16 vmovaps(mem(rbx, 1*32), ymm1) // iteration 2 - prefetch(0, mem(rax, 76*4)) + prefetch(0, mem(rax, 80*4)) vbroadcastss(mem(rax, 12*4), ymm2) vbroadcastss(mem(rax, 13*4), ymm3) @@ -870,7 +884,7 @@ void bli_sgemm_haswell_asm_6x16 label(.SDONE) - + vzeroupper() end_asm( : // output operands (none) @@ -950,7 +964,21 @@ void bli_dgemm_haswell_asm_6x8 begin_asm() - vzeroall() // zero all xmm/ymm registers. + //vzeroall() // zero all xmm/ymm registers. + + vxorpd( ymm4, ymm4, ymm4) // vzeroall is expensive + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + mov(var(a), rax) // load address of a. @@ -996,76 +1024,78 @@ void bli_dgemm_haswell_asm_6x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 2*8), ymm2) vbroadcastsd(mem(rax, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 4*8), ymm2) vbroadcastsd(mem(rax, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, -2*32), ymm0) vmovapd(mem(rbx, -1*32), ymm1) - + // iteration 1 + prefetch(0, mem(rax, 72*8)) + vbroadcastsd(mem(rax, 6*8), ymm2) vbroadcastsd(mem(rax, 7*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 8*8), ymm2) vbroadcastsd(mem(rax, 9*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 10*8), ymm2) vbroadcastsd(mem(rax, 11*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, 0*32), ymm0) vmovapd(mem(rbx, 1*32), ymm1) - + // iteration 2 - prefetch(0, mem(rax, 76*8)) - + prefetch(0, mem(rax, 80*8)) + vbroadcastsd(mem(rax, 12*8), ymm2) vbroadcastsd(mem(rax, 13*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 14*8), ymm2) vbroadcastsd(mem(rax, 15*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 16*8), ymm2) vbroadcastsd(mem(rax, 17*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, 2*32), ymm0) vmovapd(mem(rbx, 3*32), ymm1) - + // iteration 3 vbroadcastsd(mem(rax, 18*8), ymm2) vbroadcastsd(mem(rax, 19*8), ymm3) @@ -1073,91 +1103,91 @@ void bli_dgemm_haswell_asm_6x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + 
vbroadcastsd(mem(rax, 20*8), ymm2) vbroadcastsd(mem(rax, 21*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 22*8), ymm2) vbroadcastsd(mem(rax, 23*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(4*6*8), rax) // a += 4*6 (unroll x mr) add(imm(4*8*8), rbx) // b += 4*8 (unroll x nr) - + vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 64*8)) - + vbroadcastsd(mem(rax, 0*8), ymm2) vbroadcastsd(mem(rax, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 2*8), ymm2) vbroadcastsd(mem(rax, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 4*8), ymm2) vbroadcastsd(mem(rax, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(1*6*8), rax) // a += 1*6 (unroll x mr) add(imm(1*8*8), rbx) // b += 1*8 (unroll x nr) - + vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - - - - + + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm0, ymm6, ymm6) @@ -1170,179 +1200,179 @@ void bli_dgemm_haswell_asm_6x8 vmulpd(ymm0, ymm13, ymm13) vmulpd(ymm0, ymm14, ymm14) vmulpd(ymm0, ymm15, ymm15) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; - + lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; //lea(mem(rsi, rsi, 4), r15) // r15 = 5*cs_c; //lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c; - - + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. jz(.DROWSTORED) // jump to row storage case - - + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORED) // jump to column storage case - - - + + + label(.DGENSTORED) - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm4, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm6, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm8, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm10, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm12, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm14, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ - - + + mov(rdx, rcx) // rcx = c + 4*cs_c - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm5, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm7, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm9, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm11, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm13, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm15, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ - - - + + + jmp(.DDONE) // jump to end. - - - + + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx), ymm3, ymm4) vmovupd(ymm4, mem(rcx)) add(rdi, rcx) vfmadd231pd(mem(rdx), ymm3, ymm5) vmovupd(ymm5, mem(rdx)) add(rdi, rdx) - - + + vfmadd231pd(mem(rcx), ymm3, ymm6) vmovupd(ymm6, mem(rcx)) add(rdi, rcx) vfmadd231pd(mem(rdx), ymm3, ymm7) vmovupd(ymm7, mem(rdx)) add(rdi, rdx) - - + + vfmadd231pd(mem(rcx), ymm3, ymm8) vmovupd(ymm8, mem(rcx)) add(rdi, rcx) vfmadd231pd(mem(rdx), ymm3, ymm9) vmovupd(ymm9, mem(rdx)) add(rdi, rdx) - - + + vfmadd231pd(mem(rcx), ymm3, ymm10) vmovupd(ymm10, mem(rcx)) add(rdi, rcx) vfmadd231pd(mem(rdx), ymm3, ymm11) vmovupd(ymm11, mem(rdx)) add(rdi, rdx) - - + + vfmadd231pd(mem(rcx), ymm3, ymm12) vmovupd(ymm12, mem(rcx)) add(rdi, rcx) vfmadd231pd(mem(rdx), ymm3, ymm13) vmovupd(ymm13, mem(rdx)) add(rdi, rdx) - - + + vfmadd231pd(mem(rcx), ymm3, ymm14) vmovupd(ymm14, mem(rcx)) //add(rdi, rcx) vfmadd231pd(mem(rdx), ymm3, ymm15) vmovupd(ymm15, mem(rdx)) //add(rdi, rdx) - - - + + + jmp(.DDONE) // jump to end. 
- - - + + + label(.DCOLSTORED) - - + + vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -1351,9 +1381,9 @@ void bli_dgemm_haswell_asm_6x8 vinsertf128(imm(0x1), xmm3, ymm1, ymm6) vperm2f128(imm(0x31), ymm2, ymm0, ymm8) vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - + vbroadcastsd(mem(rbx), ymm3) - + vfmadd231pd(mem(rcx), ymm3, ymm4) vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) @@ -1362,14 +1392,14 @@ void bli_dgemm_haswell_asm_6x8 vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, r13, 1)) - + lea(mem(rcx, rsi, 4), rcx) - + vunpcklpd(ymm14, ymm12, ymm0) vunpckhpd(ymm14, ymm12, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - + vfmadd231pd(mem(r14), xmm3, xmm0) vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) @@ -1378,10 +1408,10 @@ void bli_dgemm_haswell_asm_6x8 vmovupd(xmm1, mem(r14, rsi, 1)) vmovupd(xmm2, mem(r14, rsi, 2)) vmovupd(xmm4, mem(r14, r13, 1)) - + lea(mem(r14, rsi, 4), r14) - - + + vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) @@ -1390,9 +1420,9 @@ void bli_dgemm_haswell_asm_6x8 vinsertf128(imm(0x1), xmm3, ymm1, ymm7) vperm2f128(imm(0x31), ymm2, ymm0, ymm9) vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - + vbroadcastsd(mem(rbx), ymm3) - + vfmadd231pd(mem(rcx), ymm3, ymm5) vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) @@ -1401,14 +1431,14 @@ void bli_dgemm_haswell_asm_6x8 vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) vmovupd(ymm11, mem(rcx, r13, 1)) - + //lea(mem(rcx, rsi, 4), rcx) - + vunpcklpd(ymm15, ymm13, ymm0) vunpckhpd(ymm15, ymm13, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - + vfmadd231pd(mem(r14), xmm3, xmm0) vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) @@ -1417,139 +1447,139 @@ void bli_dgemm_haswell_asm_6x8 vmovupd(xmm1, mem(r14, rsi, 1)) vmovupd(xmm2, mem(r14, rsi, 2)) vmovupd(xmm4, mem(r14, r13, 1)) - + //lea(mem(r14, rsi, 4), r14) - - - + + + jmp(.DDONE) // jump to end. - - - + + + label(.DBETAZERO) - + cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. jz(.DROWSTORBZ) // jump to row storage case - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - - - + + + label(.DGENSTORBZ) - - + + vmovapd(ymm4, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm6, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm8, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm10, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm12, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm14, ymm0) DGEMM_OUTPUT_GS_BETA_NZ - - + + mov(rdx, rcx) // rcx = c + 4*cs_c - - + + vmovapd(ymm5, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm7, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm9, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm11, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm13, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm15, ymm0) DGEMM_OUTPUT_GS_BETA_NZ - - - + + + jmp(.DDONE) // jump to end. 
- - - + + + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx)) add(rdi, rcx) vmovupd(ymm5, mem(rdx)) add(rdi, rdx) - + vmovupd(ymm6, mem(rcx)) add(rdi, rcx) vmovupd(ymm7, mem(rdx)) add(rdi, rdx) - - + + vmovupd(ymm8, mem(rcx)) add(rdi, rcx) vmovupd(ymm9, mem(rdx)) add(rdi, rdx) - - + + vmovupd(ymm10, mem(rcx)) add(rdi, rcx) vmovupd(ymm11, mem(rdx)) add(rdi, rdx) - - + + vmovupd(ymm12, mem(rcx)) add(rdi, rcx) vmovupd(ymm13, mem(rdx)) add(rdi, rdx) - - + + vmovupd(ymm14, mem(rcx)) //add(rdi, rcx) vmovupd(ymm15, mem(rdx)) //add(rdi, rdx) - - + + jmp(.DDONE) // jump to end. - - - + + + label(.DCOLSTORBZ) - - + + vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -1558,27 +1588,27 @@ void bli_dgemm_haswell_asm_6x8 vinsertf128(imm(0x1), xmm3, ymm1, ymm6) vperm2f128(imm(0x31), ymm2, ymm0, ymm8) vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - + vmovupd(ymm4, mem(rcx)) vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, r13, 1)) - + lea(mem(rcx, rsi, 4), rcx) - + vunpcklpd(ymm14, ymm12, ymm0) vunpckhpd(ymm14, ymm12, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - + vmovupd(xmm0, mem(r14)) vmovupd(xmm1, mem(r14, rsi, 1)) vmovupd(xmm2, mem(r14, rsi, 2)) vmovupd(xmm4, mem(r14, r13, 1)) - + lea(mem(r14, rsi, 4), r14) - - + + vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) @@ -1587,31 +1617,31 @@ void bli_dgemm_haswell_asm_6x8 vinsertf128(imm(0x1), xmm3, ymm1, ymm7) vperm2f128(imm(0x31), ymm2, ymm0, ymm9) vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - + vmovupd(ymm5, mem(rcx)) vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) vmovupd(ymm11, mem(rcx, r13, 1)) - + //lea(mem(rcx, rsi, 4), rcx) - + vunpcklpd(ymm15, ymm13, ymm0) vunpckhpd(ymm15, ymm13, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - + vmovupd(xmm0, mem(r14)) vmovupd(xmm1, mem(r14, rsi, 1)) vmovupd(xmm2, mem(r14, rsi, 2)) vmovupd(xmm4, mem(r14, r13, 1)) - + //lea(mem(r14, rsi, 4), r14) - - - + + + label(.DDONE) - - + vzeroupper() + end_asm( : // output operands (none) @@ -2144,7 +2174,7 @@ void bli_cgemm_haswell_asm_3x8 label(.CDONE) - + vzeroupper() end_asm( : // output operands (none) @@ -2744,7 +2774,7 @@ void bli_zgemm_haswell_asm_3x4 label(.ZDONE) - + vzeroupper() end_asm( : // output operands (none) diff --git a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c index f6edad70bf..990358db8b 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -258,27 +258,12257 @@ void bli_dgemmsup_rd_haswell_asm_6x8m prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c #endif + lea(mem(r8, r8, 4), rcx) // rcx = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rcx, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rcx, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // 
iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rcx, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. 
+ + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + lea(mem(r12), rcx) // rcx = c_iijj; + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(8), r15) // compare jj to 8 + jl(.DLOOP3X4J) // if jj < 8, jump to beginning + // of jj loop; otherwise, loop ends. 
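The post-accumulation sequence above (vhaddpd, vextractf128, vaddpd, then vperm2f128 with selector 0x20) collapses four ymm accumulators into one ymm register holding their four scalar sums, exactly as the xmm4/xmm5/xmm6 comments describe. The same reduction written with AVX intrinsics, as a stand-alone reference (compile with AVX enabled; the helper name is illustrative):

#include <immintrin.h>
#include <stdio.h>

// Returns [ sum(a), sum(b), sum(c), sum(d) ], mirroring the
// vhaddpd / vextractf128 / vaddpd / vperm2f128 sequence in the kernel.
static __m256d hsum4x4( __m256d a, __m256d b, __m256d c, __m256d d )
{
    __m256d t0 = _mm256_hadd_pd( a, b );                    // [a0+a1, b0+b1, a2+a3, b2+b3]
    __m128d s0 = _mm_add_pd( _mm256_castpd256_pd128( t0 ),
                             _mm256_extractf128_pd( t0, 1 ) ); // [sum(a), sum(b)]

    __m256d t1 = _mm256_hadd_pd( c, d );                    // [c0+c1, d0+d1, c2+c3, d2+d3]
    __m128d s1 = _mm_add_pd( _mm256_castpd256_pd128( t1 ),
                             _mm256_extractf128_pd( t1, 1 ) ); // [sum(c), sum(d)]

    // Concatenate the two 128-bit halves, as vperm2f128 0x20 does.
    return _mm256_insertf128_pd( _mm256_castpd128_pd256( s0 ), s1, 1 );
}

int main( void )
{
    __m256d v = _mm256_set_pd( 4.0, 3.0, 2.0, 1.0 );  // elements sum to 10
    double out[4];
    _mm256_storeu_pd( out, hsum4x4( v, v, v, v ) );
    printf( "%g %g %g %g\n", out[0], out[1], out[2], out[3] );  // 10 10 10 10
    return 0;
}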
+ + + + label(.DRETURN) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* +24x24 block + + 1 1 1 1 1 1 1 1 1 1 2 2 2 2 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 + |- - - - - - - -|- - - - - - - -| - - - - - - - -| +0 | | | | +1 | m_off_24 = 0 | | | +2 | n_off_24 = 0 | | | +3 | m_idx = 0 | | | +4 | n_idx = 0 | | | +5 |- - - - - - - -|- - - - - - - -|- - - - - - - - | +6 | | | | +7 | m_off_24 = 6 | m_off_24 = 6 | | +8 | n_off_24 = 0 | n_off_24 = 8 | | +9 | m_idx = 1 | m_idx = 1 | | +10 | n_idx = 0 | n_idx = 1 | | +11 |- - - - - - - -|- - - - - - - -|- - - - - - - - | +12 | | | | +13 | | m_off_24 = 12 | m_off_24 = 12 | +14 | | n_off_24 = 8 | n_off_24 = 16 | +15 | | m_idx = 2 | m_idx = 2 | +16 | | n_idx = 1 | n_idx = 2 | +17 |- - - - - - - -|- - - - - - - -|- - - - - - - - | +18 | | | | +19 | | | m_off_24 = 18 | +20 | | | n_off_24 = 16 | +21 | | | m_idx = 3 | +22 | | | n_idx = 2 | +23 |- - - - - - - -|- - - - - - - -|- - - - - - - - | +*/ + + +#define SUBITER_K4_3x4(a, b) \ +\ + vmovupd(mem(a ), ymm0) \ + vmovupd(mem(a, r8, 1), ymm1) \ + vmovupd(mem(a, r8, 2), ymm2) \ + add(imm(4*8), a) \ +\ + vmovupd(mem(b ), ymm3) \ + vfmadd231pd(ymm0, ymm3, ymm4) \ + vfmadd231pd(ymm1, ymm3, ymm5) \ + vfmadd231pd(ymm2, ymm3, ymm6) \ +\ + vmovupd(mem(b, r11, 1), ymm3) \ + vfmadd231pd(ymm0, ymm3, ymm7) \ + vfmadd231pd(ymm1, ymm3, ymm8) \ + vfmadd231pd(ymm2, ymm3, ymm9) \ +\ + vmovupd(mem(b, r11, 2), ymm3) \ + vfmadd231pd(ymm0, ymm3, ymm10) \ + vfmadd231pd(ymm1, ymm3, ymm11) \ + vfmadd231pd(ymm2, ymm3, ymm12) \ +\ + vmovupd(mem(b, r13, 1), ymm3) \ + add(imm(4*8), b) \ + vfmadd231pd(ymm0, ymm3, ymm13) \ + vfmadd231pd(ymm1, ymm3, ymm14) \ + vfmadd231pd(ymm2, ymm3, ymm15) \ + +#define SUBITER_K1_3x4(a, b) \ +\ + vmovsd(mem(a ), xmm0) \ + vmovsd(mem(a, r8, 1), xmm1) \ + vmovsd(mem(a, r8, 2), xmm2) \ + add(imm(1*8), a) \ +\ + vmovsd(mem(b ), xmm3) \ + vfmadd231pd(ymm0, ymm3, ymm4) \ + vfmadd231pd(ymm1, ymm3, ymm5) \ + vfmadd231pd(ymm2, ymm3, ymm6) \ +\ + vmovsd(mem(b, r11, 1), xmm3) \ + 
vfmadd231pd(ymm0, ymm3, ymm7) \ + vfmadd231pd(ymm1, ymm3, ymm8) \ + vfmadd231pd(ymm2, ymm3, ymm9) \ +\ + vmovsd(mem(b, r11, 2), xmm3) \ + vfmadd231pd(ymm0, ymm3, ymm10) \ + vfmadd231pd(ymm1, ymm3, ymm11) \ + vfmadd231pd(ymm2, ymm3, ymm12) \ +\ + vmovsd(mem(b, r13, 1), xmm3) \ + add(imm(1*8), b) \ + vfmadd231pd(ymm0, ymm3, ymm13) \ + vfmadd231pd(ymm1, ymm3, ymm14) \ + vfmadd231pd(ymm2, ymm3, ymm15) \ + +#define SUBITER_K4_2x4(a, b) \ +\ + vmovupd(mem(a ), ymm0) \ + vmovupd(mem(a, r8, 1), ymm1) \ + add(imm(4*8), a) \ +\ + vmovupd(mem(b ), ymm3) \ + vfmadd231pd(ymm0, ymm3, ymm4) \ + vfmadd231pd(ymm1, ymm3, ymm5) \ +\ + vmovupd(mem(b, r11, 1), ymm3) \ + vfmadd231pd(ymm0, ymm3, ymm7) \ + vfmadd231pd(ymm1, ymm3, ymm8) \ +\ + vmovupd(mem(b, r11, 2), ymm3) \ + vfmadd231pd(ymm0, ymm3, ymm10) \ + vfmadd231pd(ymm1, ymm3, ymm11) \ +\ + vmovupd(mem(b, r13, 1), ymm3) \ + add(imm(4*8), b) \ + vfmadd231pd(ymm0, ymm3, ymm13) \ + vfmadd231pd(ymm1, ymm3, ymm14) \ + +#define SUBITER_K1_2x4(a, b) \ +\ + vmovsd(mem(a ), xmm0) \ + vmovsd(mem(a, r8, 1), xmm1) \ + add(imm(1*8), a) \ +\ + vmovsd(mem(b ), xmm3) \ + vfmadd231pd(ymm0, ymm3, ymm4) \ + vfmadd231pd(ymm1, ymm3, ymm5) \ +\ + vmovsd(mem(b, r11, 1), xmm3) \ + vfmadd231pd(ymm0, ymm3, ymm7) \ + vfmadd231pd(ymm1, ymm3, ymm8) \ +\ + vmovsd(mem(b, r11, 2), xmm3) \ + vfmadd231pd(ymm0, ymm3, ymm10) \ + vfmadd231pd(ymm1, ymm3, ymm11) \ +\ + vmovsd(mem(b, r13, 1), xmm3) \ + add(imm(1*8), b) \ + vfmadd231pd(ymm0, ymm3, ymm13) \ + vfmadd231pd(ymm1, ymm3, ymm14) \ + +#define SUBITER_K4_1x4(a, b) \ +\ + vmovupd(mem(a ), ymm0) \ + add(imm(4*8), a) \ + vmovupd(mem(b ), ymm3) \ + vfmadd231pd(ymm0, ymm3, ymm4) \ + vmovupd(mem(b, r11, 1), ymm3) \ + vfmadd231pd(ymm0, ymm3, ymm7) \ + vmovupd(mem(b, r11, 2), ymm3) \ + vfmadd231pd(ymm0, ymm3, ymm10) \ + vmovupd(mem(b, r13, 1), ymm3) \ + add(imm(4*8), b) \ + vfmadd231pd(ymm0, ymm3, ymm13) \ + +#define SUBITER_K1_1x4(a, b) \ +\ + vmovsd(mem(a ), xmm0) \ + add(imm(1*8), a) \ + vmovsd(mem(b ), xmm3) \ + vfmadd231pd(ymm0, ymm3, ymm4) \ + vmovsd(mem(b, r11, 1), xmm3) \ + vfmadd231pd(ymm0, ymm3, ymm7) \ + vmovsd(mem(b, r11, 2), xmm3) \ + vfmadd231pd(ymm0, ymm3, ymm10) \ + vmovsd(mem(b, r13, 1), xmm3) \ + add(imm(1*8), b) \ + vfmadd231pd(ymm0, ymm3, ymm13) \ + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 0 and n_offset is 0(0x0) +(0x0)_U + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 0 1 2 3 4 5 6 7 + +ā†‘ 0 x x x x x x x x +| 1 - x x x x x x x +m 2 - - x x x x x x +off 3 - - - x x x x x +24 4 - - - - x x x x +| 5 - - - - - x x x +ā†“ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_0x0_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // 
load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + // ---------------------------------- iteration 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. 
+ label(.DPOSTACCUM_BLOCK1) + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + // now avoid loading C if beta == 0 + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm5, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx ,2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vextractf128(imm(1), ymm6, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + jmp(.DDONE_BLOCK1) // jump to end. + label(.DBETAZERO_BLOCK1) + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm5, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx ,2*8)) + add(rdi, rcx) + vextractf128(imm(1), ymm6, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + label(.DDONE_BLOCK1) + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + SUBITER_K4_1x4(rax, rbx) + // ---------------------------------- iteration 1 + SUBITER_K4_1x4(rax, rbx) + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) + SUBITER_K4_1x4(rax, rbx) + // ---------------------------------- iteration 3 + SUBITER_K4_1x4(rax, rbx) + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK2) + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + SUBITER_K4_1x4(rax, rbx) + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_1x4(rax, rbx) + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovhpd(xmm1, mem(rcx ,3*8)) + jmp(.DDONE_BLOCK2) // jump to end. 
+ label(.DBETAZERO_BLOCK2) + + vextractf128(imm(1), ymm4, xmm1 ) + vmovhpd(xmm1, mem(rcx ,3*8)) + label(.DDONE_BLOCK2) + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + dec(r9) // ii -= 1; + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + // ---------------------------------- iteration 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. 
+ + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. + label(.DPOSTACCUM_BLOCK3) + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + jmp(.DDONE_BLOCK3) // jump to end. + + label(.DBETAZERO_BLOCK3) + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + label(.DDONE_BLOCK3) + + lea(mem(r12, rdi, 2), r12) + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + // ---------------------------------- iteration 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. 
+	vucomisd(xmm0, xmm3)               // set ZF if beta == 0.
+	je(.DBETAZERO_BLOCK4)              // if ZF = 1, jump to beta == 0 case
+
+
+	vfmadd231pd(mem(rcx), ymm3, ymm4)
+	vmovupd(ymm4, mem(rcx))
+	add(rdi, rcx)
+
+	vfmadd231pd(mem(rcx), ymm3, ymm5)
+	vmovupd(ymm5, mem(rcx))
+	add(rdi, rcx)
+
+	vfmadd231pd(mem(rcx), ymm3, ymm6)
+	vextractf128(imm(1), ymm6, xmm1 )
+	vmovhpd(xmm6, mem(rcx, 1*8))
+	vmovupd(xmm1, mem(rcx, 2*8))
+
+
+	jmp(.DDONE_BLOCK4)                 // jump to end.
+
+	label(.DBETAZERO_BLOCK4)
+
+
+	vmovupd(ymm4, mem(rcx))
+	add(rdi, rcx)
+
+	vmovupd(ymm5, mem(rcx))
+	add(rdi, rcx)
+
+	vextractf128(imm(1), ymm6, xmm1 )
+	vmovhpd(xmm6, mem(rcx, 1*8))
+	vmovupd(xmm1, mem(rcx, 2*8))
+
+
+	label(.DDONE_BLOCK4)
+
+	vzeroupper()
+
+	end_asm(
+	: // output operands (none)
+	: // input operands
+	  [m_iter]   "m" (m_iter),
+	  [k_iter16] "m" (k_iter16),
+	  [k_iter4]  "m" (k_iter4),
+	  [k_left1]  "m" (k_left1),
+	  [a]        "m" (a),
+	  [rs_a]     "m" (rs_a),
+	  [cs_a]     "m" (cs_a),
+	  [b]        "m" (b),
+	  [rs_b]     "m" (rs_b),
+	  [cs_b]     "m" (cs_b),
+	  [alpha]    "m" (alpha),
+	  [beta]     "m" (beta),
+	  [c]        "m" (c),
+	  [rs_c]     "m" (rs_c),
+	  [cs_c]     "m" (cs_c)/*,
+	  [a_next]   "m" (a_next),
+	  [b_next]   "m" (b_next)*/
+	: // register clobber list
+	  "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
+	  "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
+	  "xmm0", "xmm1", "xmm2", "xmm3",
+	  "xmm4", "xmm5", "xmm6", "xmm7",
+	  "xmm8", "xmm9", "xmm10", "xmm11",
+	  "xmm12", "xmm13", "xmm14", "xmm15",
+	  "memory"
+	)
+
+	AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7);
+}
+
+/*
+
+The following kernel computes the 6x8 block for the Upper variant (U) of gemmt
+where m_offset in the 24x24 block is 6 and n_offset is 0 (6x0).
+(6x0)_U
+
+
+The region marked with 'x' is computed by the following kernel.
+The region marked with '-' is not computed.
+
+     <-- n_off_24 -- >
+      0 1 2 3 4 5 6 7
+
+↑   6  - - - - - - x x
+|   7  - - - - - - - x
+m   8  - - - - - - - -
+off 9  - - - - - - - -
+24 10  - - - - - - - -
+|  11  - - - - - - - -
+↓
+
+
+*/
+void bli_dgemmsup_rd_haswell_asm_6x8m_6x0_U
+     (
+       conj_t              conja,
+       conj_t              conjb,
+       dim_t               m0,
+       dim_t               n0,
+       dim_t               k0,
+       double*    restrict alpha,
+       double*    restrict a, inc_t rs_a0, inc_t cs_a0,
+       double*    restrict b, inc_t rs_b0, inc_t cs_b0,
+       double*    restrict beta,
+       double*    restrict c, inc_t rs_c0, inc_t cs_c0,
+       auxinfo_t* restrict data,
+       cntx_t*    restrict cntx
+     )
+{
+	AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7);
+	uint64_t k_iter16 = k0 / 16;
+	uint64_t k_left16 = k0 % 16;
+	uint64_t k_iter4  = k_left16 / 4;
+	uint64_t k_left1  = k_left16 % 4;
+
+	uint64_t m_iter = m0 / 3;
+
+	uint64_t rs_a = rs_a0;
+	uint64_t cs_a = cs_a0;
+	uint64_t rs_b = rs_b0;
+	uint64_t cs_b = cs_b0;
+	uint64_t rs_c = rs_c0;
+	uint64_t cs_c = cs_c0;
+
+	begin_asm()
+
+	mov(var(rs_a), r8)                 // load rs_a
+	lea(mem(, r8, 8), r8)              // rs_a *= sizeof(double)
+	mov(var(cs_b), r11)                // load cs_b
+	lea(mem(, r11, 8), r11)            // cs_b *= sizeof(double)
+
+	lea(mem(r11, r11, 2), r13)         // r13 = 3*cs_b
+	lea(mem(r8, r8, 2), r10)           // r10 = 3*rs_a
+
+	mov(var(rs_c), rdi)                // load rs_c
+	lea(mem(, rdi, 8), rdi)            // rs_c *= sizeof(double)
+
+	mov(imm(4), r15)                   // jj = 4;
+// ----------------------- Block 3
+	mov(var(a), r14)                   // load address of a
+	mov(var(b), rdx)                   // load address of b
+	mov(var(c), r12)                   // load address of c
+
+	lea(mem( , r15, 1), rsi)           // rsi = r15 = 4*jj;
+	imul(imm(1*8), rsi)                // rsi *= cs_c*sizeof(double) = 1*8
+	lea(mem(r12, rsi, 1), r12)         // r12 = c + 4*jj*cs_c;
+
+	lea(mem( , r15, 1), rsi)           // rsi = r15 = 4*jj;
+	imul(r11, rsi)                     // rsi *= cs_b;
+	lea(mem(rdx, rsi, 1), rdx)         // rbx = b + 4*jj*cs_b;
+
+	mov(var(m_iter), r9)
// ii = m_iter; + +// ----------------------- Block 3 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + // ---------------------------------- iteration 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_2x4(rax, rbx) + // ---------------------------------- iteration 1 + SUBITER_K4_2x4(rax, rbx) + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_2x4(rax, rbx) + // ---------------------------------- iteration 3 + SUBITER_K4_2x4(rax, rbx) + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_2x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + SUBITER_K1_2x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. 
+ + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm1, mem(rcx, 3*8)) + add(rdi, rcx) + + // vfmadd231pd(mem(rcx), ymm3, ymm6) + // vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK3) // jump to end. 
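+
+	// beta == 0 case below: C is not read at all. Only the elements of
+	// this 3x4 tile that lie inside the stored upper region are written
+	// back -- two elements of the first row and one of the second; the
+	// third row falls entirely outside the region (hence the stores
+	// commented out above).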
+
+	label(.DBETAZERO_BLOCK3)
+
+
+	vextractf128(imm(1), ymm4, xmm1 )
+	vmovupd(xmm1, mem(rcx, 2*8))
+	add(rdi, rcx)
+
+	vextractf128(imm(1), ymm5, xmm1 )
+	vmovhpd(xmm1, mem(rcx, 3*8))
+	add(rdi, rcx)
+
+
+	label(.DDONE_BLOCK3)
+
+	vzeroupper()
+
+	end_asm(
+	: // output operands (none)
+	: // input operands
+	  [m_iter]   "m" (m_iter),
+	  [k_iter16] "m" (k_iter16),
+	  [k_iter4]  "m" (k_iter4),
+	  [k_left1]  "m" (k_left1),
+	  [a]        "m" (a),
+	  [rs_a]     "m" (rs_a),
+	  [cs_a]     "m" (cs_a),
+	  [b]        "m" (b),
+	  [rs_b]     "m" (rs_b),
+	  [cs_b]     "m" (cs_b),
+	  [alpha]    "m" (alpha),
+	  [beta]     "m" (beta),
+	  [c]        "m" (c),
+	  [rs_c]     "m" (rs_c),
+	  [cs_c]     "m" (cs_c)/*,
+	  [a_next]   "m" (a_next),
+	  [b_next]   "m" (b_next)*/
+	: // register clobber list
+	  "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
+	  "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
+	  "xmm0", "xmm1", "xmm2", "xmm3",
+	  "xmm4", "xmm5", "xmm6", "xmm7",
+	  "xmm8", "xmm9", "xmm10", "xmm11",
+	  "xmm12", "xmm13", "xmm14", "xmm15",
+	  "memory"
+	)
+
+	AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7);
+}
+
+/*
+
+The following kernel computes the 6x8 blocks for the Upper variant (U) of gemmt
+where m_offset in the 24x24 block is 0 and n_offset is 0 (0x0), and where
+m_offset is 6 and n_offset is 0 (6x0).
+(0x0)+(6x0)_U
+
+The region marked with 'x' is computed by the following kernel.
+The region marked with '-' is not computed.
+
+     <-- n_off_24 -- >
+      0 1 2 3 4 5 6 7
+
+↑   0  x x x x x x x x
+|   1  - x x x x x x x
+m   2  - - x x x x x x
+off 3  - - - x x x x x
+24  4  - - - - x x x x
+|   5  - - - - - x x x
+↓
+↑   6  - - - - - - x x
+|   7  - - - - - - - x
+m   8  - - - - - - - -
+off 9  - - - - - - - -
+24 10  - - - - - - - -
+|  11  - - - - - - - -
+↓
+
+
+*/
+void bli_dgemmsup_rd_haswell_asm_6x8m_0x0_combined_U
+     (
+       conj_t              conja,
+       conj_t              conjb,
+       dim_t               m0,
+       dim_t               n0,
+       dim_t               k0,
+       double*    restrict alpha,
+       double*    restrict a, inc_t rs_a0, inc_t cs_a0,
+       double*    restrict b, inc_t rs_b0, inc_t cs_b0,
+       double*    restrict beta,
+       double*    restrict c, inc_t rs_c0, inc_t cs_c0,
+       auxinfo_t* restrict data,
+       cntx_t*    restrict cntx
+     )
+{
+	AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7);
+	uint64_t k_iter16 = k0 / 16;
+	uint64_t k_left16 = k0 % 16;
+	uint64_t k_iter4  = k_left16 / 4;
+	uint64_t k_left1  = k_left16 % 4;
+
+	uint64_t m_iter = m0 / 3;
+
+	uint64_t rs_a = rs_a0;
+	uint64_t cs_a = cs_a0;
+	uint64_t rs_b = rs_b0;
+	uint64_t cs_b = cs_b0;
+	uint64_t rs_c = rs_c0;
+	uint64_t cs_c = cs_c0;
+
+	begin_asm()
+
+	mov(var(rs_a), r8)                 // load rs_a
+	lea(mem(, r8, 8), r8)              // rs_a *= sizeof(double)
+	mov(var(cs_b), r11)                // load cs_b
+	lea(mem(, r11, 8), r11)            // cs_b *= sizeof(double)
+
+	lea(mem(r11, r11, 2), r13)         // r13 = 3*cs_b
+	lea(mem(r8, r8, 2), r10)           // r10 = 3*rs_a
+
+	mov(var(rs_c), rdi)                // load rs_c
+	lea(mem(, rdi, 8), rdi)            // rs_c *= sizeof(double)
+
+	mov(imm(0), r15)                   // jj = 0;
+
+// ----------------------- Block 1
+
+	mov(var(a), r14)                   // load address of a
+	mov(var(b), rdx)                   // load address of b
+	mov(var(c), r12)                   // load address of c
+
+	lea(mem( , r15, 1), rsi)           // rsi = r15 = 4*jj;
+	imul(imm(1*8), rsi)                // rsi *= cs_c*sizeof(double) = 1*8
+	lea(mem(r12, rsi, 1), r12)         // r12 = c + 4*jj*cs_c;
+
+	lea(mem( , r15, 1), rsi)           // rsi = r15 = 4*jj;
+	imul(r11, rsi)                     // rsi *= cs_b;
+	lea(mem(rdx, rsi, 1), rdx)         // rbx = b + 4*jj*cs_b;
+
+	mov(var(m_iter), r9)               // ii = m_iter;
+
+	vxorpd(ymm4, ymm4, ymm4)
+	vmovapd( ymm4, ymm5)
+	vmovapd( ymm4, ymm6)
+	vmovapd( ymm4, ymm7)
+	vmovapd( ymm4, ymm8)
+	vmovapd( ymm4, ymm9)
+	vmovapd( ymm4, ymm10)
+	vmovapd( ymm4, ymm11)
+	vmovapd( ymm4, ymm12)
+	vmovapd( ymm4, ymm13)
+	vmovapd( ymm4,
ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. 
+ + label(.DPOSTACCUM_BLOCK1) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm5, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx ,2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vextractf128(imm(1), ymm6, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + + + jmp(.DDONE_BLOCK1) // jump to end. + + label(.DBETAZERO_BLOCK1) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm5, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx ,2*8)) + + add(rdi, rcx) + + vextractf128(imm(1), ymm6, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + + + + label(.DDONE_BLOCK1) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. 
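+
+	// Block 2 note: only a single element of this 3x4 tile appears to lie
+	// inside the stored upper region, so the loops below accumulate just
+	// one row of dot products (SUBITER_K4_1x4) and the post-accumulation
+	// code writes back a single element with vmovhpd.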
+ + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + SUBITER_K4_1x4(rax, rbx) + // ---------------------------------- iteration 1 + SUBITER_K4_1x4(rax, rbx) + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + SUBITER_K4_1x4(rax, rbx) + // ---------------------------------- iteration 3 + SUBITER_K4_1x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + SUBITER_K4_1x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_1x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovhpd(xmm1, mem(rcx ,3*8)) + + + jmp(.DDONE_BLOCK2) // jump to end. 
+ + label(.DBETAZERO_BLOCK2) + + + vextractf128(imm(1), ymm4, xmm1 ) + vmovhpd(xmm1, mem(rcx ,3*8)) + + + label(.DDONE_BLOCK2) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. 
+ SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK3) // jump to end. + + label(.DBETAZERO_BLOCK3) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK3) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. 
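+
+	// The main loop above appears to be the SUBITER_K4_3x4 computation
+	// written out inline: each sub-iteration loads four consecutive
+	// k-elements from three rows of A and from four columns of B, and
+	// accumulates the twelve partial dot products into ymm4-ymm15.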
+ + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. 
+ + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vextractf128(imm(1), ymm6, xmm1 ) + vmovhpd(xmm6, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx, 2*8)) + + + jmp(.DDONE_BLOCK4) // jump to end. 
+ + label(.DBETAZERO_BLOCK4) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vextractf128(imm(1), ymm6, xmm1 ) + vmovhpd(xmm6, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx, 2*8)) + + + label(.DDONE_BLOCK4) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + a += 6 * rs_a; + c += 6 * rs_c; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + 
test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
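+
+	// Second half of the combined kernel: a and c were advanced by six
+	// rows between the two asm blocks, Blocks 1 and 2 above only step the
+	// block pointers, and the loops of this Block 3 load only two rows of
+	// A (ymm0/ymm1), since just two rows of this tile appear to intersect
+	// the stored upper region.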
+ + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. 
+	je(.DBETAZERO_BLOCK3)              // if ZF = 1, jump to beta == 0 case
+
+
+	vfmadd231pd(mem(rcx), ymm3, ymm4)
+	vextractf128(imm(1), ymm4, xmm1 )
+	vmovupd(xmm1, mem(rcx, 2*8))
+	add(rdi, rcx)
+
+	vfmadd231pd(mem(rcx), ymm3, ymm5)
+	vextractf128(imm(1), ymm5, xmm1 )
+	vmovhpd(xmm1, mem(rcx, 3*8))
+	add(rdi, rcx)
+
+	// vfmadd231pd(mem(rcx), ymm3, ymm6)
+	// vmovupd(ymm6, mem(rcx))
+
+
+	jmp(.DDONE_BLOCK3)                 // jump to end.
+
+	label(.DBETAZERO_BLOCK3)
+
+
+	vextractf128(imm(1), ymm4, xmm1 )
+	vmovupd(xmm1, mem(rcx, 2*8))
+	add(rdi, rcx)
+
+	vextractf128(imm(1), ymm5, xmm1 )
+	vmovhpd(xmm1, mem(rcx, 3*8))
+	add(rdi, rcx)
+
+
+	label(.DDONE_BLOCK3)
+
+	vzeroupper()
+
+	end_asm(
+	: // output operands (none)
+	: // input operands
+	  [m_iter]   "m" (m_iter),
+	  [k_iter16] "m" (k_iter16),
+	  [k_iter4]  "m" (k_iter4),
+	  [k_left1]  "m" (k_left1),
+	  [a]        "m" (a),
+	  [rs_a]     "m" (rs_a),
+	  [cs_a]     "m" (cs_a),
+	  [b]        "m" (b),
+	  [rs_b]     "m" (rs_b),
+	  [cs_b]     "m" (cs_b),
+	  [alpha]    "m" (alpha),
+	  [beta]     "m" (beta),
+	  [c]        "m" (c),
+	  [rs_c]     "m" (rs_c),
+	  [cs_c]     "m" (cs_c)/*,
+	  [a_next]   "m" (a_next),
+	  [b_next]   "m" (b_next)*/
+	: // register clobber list
+	  "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
+	  "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
+	  "xmm0", "xmm1", "xmm2", "xmm3",
+	  "xmm4", "xmm5", "xmm6", "xmm7",
+	  "xmm8", "xmm9", "xmm10", "xmm11",
+	  "xmm12", "xmm13", "xmm14", "xmm15",
+	  "memory"
+	)
+
+	AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7);
+}
+
+/*
+
+The following kernel computes the 6x8 block for the Upper variant (U) of gemmt
+where m_offset in the 24x24 block is 6 and n_offset is 8 (6x8).
+(6x8)_U
+
+The region marked with 'x' is computed by the following kernel.
+The region marked with '-' is not computed.
+
+     <-- n_off_24 -- >
+      8 9 10 11 12 13 14 15
+
+↑   6  x x x  x  x  x  x  x
+|   7  x x x  x  x  x  x  x
+m   8  x x x  x  x  x  x  x
+off 9  - x x  x  x  x  x  x
+24 10  - - x  x  x  x  x  x
+|  11  - - -  x  x  x  x  x
+↓
+
+
+*/
+void bli_dgemmsup_rd_haswell_asm_6x8m_6x8_U
+     (
+       conj_t              conja,
+       conj_t              conjb,
+       dim_t               m0,
+       dim_t               n0,
+       dim_t               k0,
+       double*    restrict alpha,
+       double*    restrict a, inc_t rs_a0, inc_t cs_a0,
+       double*    restrict b, inc_t rs_b0, inc_t cs_b0,
+       double*    restrict beta,
+       double*    restrict c, inc_t rs_c0, inc_t cs_c0,
+       auxinfo_t* restrict data,
+       cntx_t*    restrict cntx
+     )
+{
+	AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7);
+
+	uint64_t k_iter16 = k0 / 16;
+	uint64_t k_left16 = k0 % 16;
+	uint64_t k_iter4  = k_left16 / 4;
+	uint64_t k_left1  = k_left16 % 4;
+
+	uint64_t m_iter = m0 / 3;
+
+	uint64_t rs_a = rs_a0;
+	uint64_t cs_a = cs_a0;
+	uint64_t rs_b = rs_b0;
+	uint64_t cs_b = cs_b0;
+	uint64_t rs_c = rs_c0;
+	uint64_t cs_c = cs_c0;
+
+	begin_asm()
+
+	mov(var(rs_a), r8)                 // load rs_a
+	lea(mem(, r8, 8), r8)              // rs_a *= sizeof(double)
+	mov(var(cs_b), r11)                // load cs_b
+	lea(mem(, r11, 8), r11)            // cs_b *= sizeof(double)
+
+	lea(mem(r11, r11, 2), r13)         // r13 = 3*cs_b
+	lea(mem(r8, r8, 2), r10)           // r10 = 3*rs_a
+
+	mov(var(rs_c), rdi)                // load rs_c
+	lea(mem(, rdi, 8), rdi)            // rs_c *= sizeof(double)
+
+	mov(imm(0), r15)                   // jj = 0;
+
+// ----------------------- Block 1
+
+	mov(var(a), r14)                   // load address of a
+	mov(var(b), rdx)                   // load address of b
+	mov(var(c), r12)                   // load address of c
+
+	lea(mem( , r15, 1), rsi)           // rsi = r15 = 4*jj;
+	imul(imm(1*8), rsi)                // rsi *= cs_c*sizeof(double) = 1*8
+	lea(mem(r12, rsi, 1), r12)         // r12 = c + 4*jj*cs_c;
+
+	lea(mem( , r15, 1), rsi)           // rsi = r15 = 4*jj;
+	imul(r11, rsi)                     // rsi *= cs_b;
+	lea(mem(rdx, rsi, 1), rdx)         // rbx = b + 4*jj*cs_b;
+
+	mov(var(m_iter), r9)               // ii = m_iter;
+
+	vxorpd(ymm4, ymm4,
ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. 
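+
+	// Post-accumulation (below): each ymm accumulator holds four partial
+	// sums along k; the vhaddpd/vextractf128/vperm2f128 sequence collapses
+	// them so that ymm4/ymm5/ymm6 each end up holding the four finished
+	// dot products for one row of the 3x4 tile. Roughly, in scalar terms
+	// (dot() is shorthand for the k-length dot product, not an actual
+	// routine):
+	//
+	//     for ( i = 0; i < 3; ++i )
+	//         for ( j = 0; j < 4; ++j )
+	//             c[ i*rs_c + j*cs_c ] = beta * c[ i*rs_c + j*cs_c ]
+	//                                  + alpha * dot( a_row(i), b_col(j) );
+	//
+	// with the beta term skipped when beta == 0.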
+ + label(.DPOSTACCUM_BLOCK1) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK1) // jump to end. + + label(.DBETAZERO_BLOCK1) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK1) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. 
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovhpd(xmm4, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vextractf128(imm(1), ymm5, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vextractf128(imm(1), ymm6, xmm1 ) + vmovhpd(xmm1, mem(rcx, 3*8)) + + + jmp(.DDONE_BLOCK2) // jump to end. + + label(.DBETAZERO_BLOCK2) + + + vextractf128(imm(1), ymm4, xmm1 ) + vmovhpd(xmm4, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vextractf128(imm(1), ymm5, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vextractf128(imm(1), ymm6, xmm1 ) + vmovhpd(xmm1, mem(rcx, 3*8)) + + + label(.DDONE_BLOCK2) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
+ + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK3) // jump to end. 
+ + label(.DBETAZERO_BLOCK3) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK3) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. 
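+
+	/* A minimal scalar model of what the k loops above are assumed to
+	   accumulate.  SUBITER_K4_3x4 / SUBITER_K1_3x4 are helper macros used
+	   throughout this file; the sketch below only reflects the accumulator
+	   layout documented in the .DPOSTACCUM register map (each of
+	   ymm4..ymm15 holds four k-lane partial products of one (i,j) dot
+	   product) and assumes the usual rd-kernel storage (unit k stride in
+	   both a and b).  It is not the literal macro implementation.
+
+	       double acc[3][4][4] = {{{0.0}}};   // [row][col][k lane]
+
+	       // one SUBITER_K4-style step: consume 4 k values for a 3x4 tile
+	       for ( int kk = 0; kk < 4; ++kk )
+	           for ( int ii = 0; ii < 3; ++ii )
+	               for ( int jj = 0; jj < 4; ++jj )
+	                   acc[ii][jj][kk] += a[ii*rs_a + kk] * b[jj*cs_b + kk];
+	       // SUBITER_K1 performs the same update for a single k value.
+	*/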
+ + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK4) // jump to end. 
+ + label(.DBETAZERO_BLOCK4) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK4) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 12 and n_offset is 8(12x8) +(12x8)_U + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 8 9 10 11 12 13 14 15 + +ā†‘ 12 - - - - x x x x +| 13 - - - - - x x x +m 14 - - - - - - x x +off 15 - - - - - - - x +24 16 - - - - - - - - +| 17 - - - - - - - - +ā†“ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_12x8_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(4), r15) // jj = 0; + +// ----------------------- Block 3 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 
1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. 
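+
+	/* The .DPOSTACCUM sequence below collapses the four k lanes of each
+	   accumulator into one value and packs the results row-wise, exactly as
+	   the xmm4/xmm5/xmm6 comments state.  A scalar sketch of the same
+	   reduction (r[][] is a hypothetical name; acc[][][] follows the layout
+	   in the register map comment):
+
+	       double r[3][4];                    // one reduced value per C element
+	       for ( int ii = 0; ii < 3; ++ii )
+	           for ( int jj = 0; jj < 4; ++jj )
+	               r[ii][jj] = acc[ii][jj][0] + acc[ii][jj][1]
+	                         + acc[ii][jj][2] + acc[ii][jj][3];
+
+	   vhaddpd/vextractf128/vaddpd form the lane sums and vperm2f128 packs
+	   r[ii][0..3] back into ymm4, ymm5 and ymm6.
+	*/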
+ + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm5, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx ,2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vextractf128(imm(1), ymm6, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + + + jmp(.DDONE_BLOCK3) // jump to end. + + label(.DBETAZERO_BLOCK3) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm5, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx ,2*8)) + + add(rdi, rcx) + + vextractf128(imm(1), ymm6, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + + + + label(.DDONE_BLOCK3) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
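+
+	/* This block covers rows 15..17 of the 24x24 gemmt block against
+	   columns 12..15.  Per the region diagram in this kernel's header
+	   comment, only the diagonal element (15,15) survives the upper-variant
+	   mask, so the loops below use the 1x4 SUBITER variants and accumulate
+	   just one C row instead of three. */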
+ + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_1x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_1x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_1x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_1x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_1x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_1x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovhpd(xmm1, mem(rcx ,3*8)) + + + jmp(.DDONE_BLOCK4) // jump to end. 
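+
+	/* The partial store above (vextractf128 + vmovhpd into rcx + 3*8)
+	   writes only the element at column offset 3 of this 1x4 tile, i.e. the
+	   diagonal entry (15,15) from the header diagram.  In scalar form the
+	   store policy of these upper-variant kernels is (m_off/n_off as in the
+	   diagram, r[][] holding the reduced, alpha-scaled results; names are
+	   illustrative only):
+
+	       for ( int ii = 0; ii < tile_m; ++ii )
+	           for ( int jj = 0; jj < tile_n; ++jj )
+	               if ( n_off + jj >= m_off + ii )   // keep the upper triangle only
+	                   c[ii*rs_c + jj*cs_c] = (*beta) * c[ii*rs_c + jj*cs_c]
+	                                        + r[ii][jj];
+	*/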
+ + label(.DBETAZERO_BLOCK4) + + + vextractf128(imm(1), ymm4, xmm1 ) + vmovhpd(xmm1, mem(rcx ,3*8)) + + + label(.DDONE_BLOCK4) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 12 and n_offset is 16(12x16) +(12x16)_U + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 16 17 18 19 20 21 22 23 + +ā†‘ 12 x x x x x x x x +| 13 x x x x x x x x +m 14 x x x x x x x x +off 15 x x x x x x x x +24 16 x x x x x x x x +| 17 - x x x x x x x +ā†“ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_12x16_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + 
prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. 
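+
+	/* The three loop nests above implement the k decomposition computed at
+	   the top of this function, k0 = 16*k_iter16 + 4*k_iter4 + k_left1.
+	   In C form the control flow is roughly (do_*() are placeholders for
+	   the SUBITER_* macro sequences, not real functions):
+
+	       for ( uint64_t i = 0; i < k_iter16; ++i )  // .DLOOPKITER16: 4 x SUBITER_K4 = 16 k steps
+	           do_16_k_steps();
+	       for ( uint64_t i = 0; i < k_iter4; ++i )   // .DLOOPKITER4:  1 x SUBITER_K4 =  4 k steps
+	           do_4_k_steps();
+	       for ( uint64_t i = 0; i < k_left1; ++i )   // .DLOOPKLEFT1:  1 x SUBITER_K1 =  1 k step
+	           do_1_k_step();
+	*/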
+ + label(.DPOSTACCUM_BLOCK1) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK1) // jump to end. + + label(.DBETAZERO_BLOCK1) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK1) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. 
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vextractf128(imm(1), ymm6, xmm1 ) + vmovhpd(xmm6, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx, 2*8)) + + + jmp(.DDONE_BLOCK2) // jump to end. + + label(.DBETAZERO_BLOCK2) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vextractf128(imm(1), ymm6, xmm1 ) + vmovhpd(xmm6, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx, 2*8)) + + + label(.DDONE_BLOCK2) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
+ + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK3) // jump to end. 
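+
+	/* As in the other blocks, the vucomisd/je pair above selects one of two
+	   epilogues so that C is never loaded when beta == 0.  The scalar
+	   equivalent of the two paths, with c_ij / ab_ij standing in for one C
+	   element and its reduced dot product:
+
+	       double t = (*alpha) * ab_ij;       // already computed by the vmulpd above
+	       if ( *beta != 0.0 )
+	           c_ij = (*beta) * c_ij + t;     // vfmadd231pd + vmovupd path
+	       else
+	           c_ij = t;                      // .DBETAZERO path: plain vmovupd
+	*/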
+ + label(.DBETAZERO_BLOCK3) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK3) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. 
+ + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK4) // jump to end. 
+
+	label(.DBETAZERO_BLOCK4)
+
+
+	vmovupd(ymm4, mem(rcx))
+	add(rdi, rcx)
+
+	vmovupd(ymm5, mem(rcx))
+	add(rdi, rcx)
+
+	vmovupd(ymm6, mem(rcx))
+
+
+	label(.DDONE_BLOCK4)
+
+	vzeroupper()
+
+	end_asm(
+	: // output operands (none)
+	: // input operands
+	[m_iter] "m" (m_iter),
+	[k_iter16] "m" (k_iter16),
+	[k_iter4] "m" (k_iter4),
+	[k_left1] "m" (k_left1),
+	[a] "m" (a),
+	[rs_a] "m" (rs_a),
+	[cs_a] "m" (cs_a),
+	[b] "m" (b),
+	[rs_b] "m" (rs_b),
+	[cs_b] "m" (cs_b),
+	[alpha] "m" (alpha),
+	[beta] "m" (beta),
+	[c] "m" (c),
+	[rs_c] "m" (rs_c),
+	[cs_c] "m" (cs_c)/*,
+	[a_next] "m" (a_next),
+	[b_next] "m" (b_next)*/
+	: // register clobber list
+	"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
+	"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
+	"xmm0", "xmm1", "xmm2", "xmm3",
+	"xmm4", "xmm5", "xmm6", "xmm7",
+	"xmm8", "xmm9", "xmm10", "xmm11",
+	"xmm12", "xmm13", "xmm14", "xmm15",
+	"memory"
+	)
+
+	AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7);
+}
+
+/*
+
+The following kernel computes the 6x8 block for the upper variant (U) of gemmt
+where m_offset in the 24x24 block is 18 and n_offset is 16 (18x16)
+(18x16)_U
+
+the region marked with 'x' is computed by the following kernel
+the region marked with '-' is not computed
+
+        <-- n_off_24 -- >
+         16 17 18 19 20 21 22 23
+
+↑    18   -  -  x  x  x  x  x  x
+|    19   -  -  -  x  x  x  x  x
+m    20   -  -  -  -  x  x  x  x
+off  21   -  -  -  -  -  x  x  x
+24   22   -  -  -  -  -  -  x  x
+|    23   -  -  -  -  -  -  -  x
+↓
+
+
+*/
+void bli_dgemmsup_rd_haswell_asm_6x8m_18x16_U
+     (
+       conj_t     conja,
+       conj_t     conjb,
+       dim_t      m0,
+       dim_t      n0,
+       dim_t      k0,
+       double*    restrict alpha,
+       double*    restrict a, inc_t rs_a0, inc_t cs_a0,
+       double*    restrict b, inc_t rs_b0, inc_t cs_b0,
+       double*    restrict beta,
+       double*    restrict c, inc_t rs_c0, inc_t cs_c0,
+       auxinfo_t* restrict data,
+       cntx_t*    restrict cntx
+     )
+{
+	AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7);
+
+	uint64_t k_iter16 = k0 / 16;
+	uint64_t k_left16 = k0 % 16;
+	uint64_t k_iter4  = k_left16 / 4;
+	uint64_t k_left1  = k_left16 % 4;
+
+	uint64_t m_iter = m0 / 3;
+
+	uint64_t rs_a = rs_a0;
+	uint64_t cs_a = cs_a0;
+	uint64_t rs_b = rs_b0;
+	uint64_t cs_b = cs_b0;
+	uint64_t rs_c = rs_c0;
+	uint64_t cs_c = cs_c0;
+
+	begin_asm()
+
+	mov(var(rs_a), r8)             // load rs_a
+	lea(mem(, r8, 8), r8)          // rs_a *= sizeof(double)
+	mov(var(cs_b), r11)            // load cs_b
+	lea(mem(, r11, 8), r11)        // cs_b *= sizeof(double)
+
+	lea(mem(r11, r11, 2), r13)     // r13 = 3*cs_b
+	lea(mem(r8, r8, 2), r10)       // r10 = 3*rs_a
+
+	mov(var(rs_c), rdi)            // load rs_c
+	lea(mem(, rdi, 8), rdi)        // rs_c *= sizeof(double)
+
+	mov(imm(0), r15)               // jj = 0;
+
+// ----------------------- Block 1
+	mov(var(a), r14)               // load address of a
+	mov(var(b), rdx)               // load address of b
+	mov(var(c), r12)               // load address of c
+
+	lea(mem( , r15, 1), rsi)       // rsi = r15 = 4*jj;
+	imul(imm(1*8), rsi)            // rsi *= cs_c*sizeof(double) = 1*8
+	lea(mem(r12, rsi, 1), r12)     // r12 = c + 4*jj*cs_c;
+
+	lea(mem( , r15, 1), rsi)       // rsi = r15 = 4*jj;
+	imul(r11, rsi)                 // rsi *= cs_b;
+	lea(mem(rdx, rsi, 1), rdx)     // rbx = b + 4*jj*cs_b;
+
+	mov(var(m_iter), r9)           // ii = m_iter;
+
+	vxorpd(ymm4, ymm4, ymm4)
+	vmovapd( ymm4, ymm5)
+	vmovapd( ymm4, ymm6)
+	vmovapd( ymm4, ymm7)
+	vmovapd( ymm4, ymm8)
+	vmovapd( ymm4, ymm9)
+	vmovapd( ymm4, ymm10)
+	vmovapd( ymm4, ymm11)
+	vmovapd( ymm4, ymm12)
+	vmovapd( ymm4, ymm13)
+	vmovapd( ymm4, ymm14)
+	vmovapd( ymm4, ymm15)
+
+	lea(mem(r12), rcx)             // rcx = c_iijj;
+	lea(mem(r14), rax)             // rax = a_ii;
+	lea(mem(rdx), rbx)             // rbx = b_jj;
+
+
+	prefetch(0, mem(rcx, 3*8))     // prefetch c + 0*rs_c
+	prefetch(0, mem(rcx, rdi, 1,
3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_2x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_2x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_2x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_2x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_2x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_2x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. 
+ + label(.DPOSTACCUM_BLOCK1) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm1, mem(rcx, 3*8)) + + // vfmadd231pd(mem(rcx), ymm3, ymm6) + // vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK1) // jump to end. + + label(.DBETAZERO_BLOCK1) + + + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm1, mem(rcx, 3*8)) + + + label(.DDONE_BLOCK1) +// ----------------------- Block 2 + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. 
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK3) // jump to end. + + label(.DBETAZERO_BLOCK3) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK3) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. 
+ SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovhpd(xmm4, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vextractf128(imm(1), ymm5, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vextractf128(imm(1), ymm6, xmm1 ) + vmovhpd(xmm1, mem(rcx, 3*8)) + + + jmp(.DDONE_BLOCK4) // jump to end. 
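+
+	/* The .DBETAZERO_* paths in these kernels store alpha*(a*b) without
+	   ever reading c.  Per stored element this is, conceptually:
+
+	     if ( beta == 0.0 ) c(i,j) =                 alpha * ab(i,j);
+	     else               c(i,j) = beta * c(i,j) + alpha * ab(i,j);
+
+	   with the beta != 0 case above folding in beta*c via vfmadd231pd. */
+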
+ + label(.DBETAZERO_BLOCK4) + + + vextractf128(imm(1), ymm4, xmm1 ) + vmovhpd(xmm4, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vextractf128(imm(1), ymm5, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vextractf128(imm(1), ymm6, xmm1 ) + vmovhpd(xmm1, mem(rcx, 3*8)) + + + label(.DDONE_BLOCK4) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 0 and n_offset is 0(0x0) +(0x0)_L + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 0 1 2 3 4 5 6 7 + +ā†‘ 0 x - - - - - - - +| 1 x x - - - - - - +m 2 x x x - - - - - +off 3 x x x x - - - - +24 4 x x x x x - - - +| 5 x x x x x x - - +ā†“ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_0x0_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), 
rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. 
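+
+	/* Each accumulator ymm4..ymm15 holds four partial sums of a single
+	   (row of a) . (column of b) dot product.  The .DPOSTACCUM_* code that
+	   follows folds the four lanes of each register and packs the results
+	   so that, e.g., ymm4[j] becomes sum_k a(i,k)*b(k,j) for the four
+	   columns j of this panel, exactly as the register-layout comments
+	   below describe. */
+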
+ + label(.DPOSTACCUM_BLOCK1) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovlpd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(xmm6, mem(rcx)) + vextractf128(imm(1), ymm6, xmm1 ) + vmovlpd(xmm1, mem(rcx, 2*8)) + + + jmp(.DDONE_BLOCK1) // jump to end. + + label(.DBETAZERO_BLOCK1) + + + vmovlpd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + vextractf128(imm(1), ymm6, xmm1 ) + vmovlpd(xmm1, mem(rcx, 2*8)) + + + label(.DDONE_BLOCK1) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. 
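+
+	/* Between 3-row blocks the c and a pointers are advanced in two lea
+	   steps, i.e. roughly c_ii += 3*rs_c; a_ii += 3*rs_a; ii -= 1;.
+	   Block 2 here covers rows 3..5 against columns 0..3, which lie
+	   entirely on or below the diagonal, so it is a full (unmasked) 3x4
+	   update with full-width stores. */
+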
+ + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. 
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK2) // jump to end. + + label(.DBETAZERO_BLOCK2) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK2) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
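+
+	/* Block 4 of this kernel covers rows 3..5 against columns 4..7.  Row 3
+	   of that panel lies entirely above the diagonal (see the diagram
+	   above the function), so the loop below loads only the a + 1*rs_a and
+	   a + 2*rs_a rows and keeps just the second- and third-row
+	   accumulators (ymm5/8/11/14 and ymm6/9/12/15), which is presumably
+	   why it is written out instead of using SUBITER_K4_3x4.  Block 3
+	   above only advanced the pointers, since none of rows 0..2 against
+	   columns 4..7 is stored. */
+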
+ + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
+ + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + + + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovlpd(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(xmm6, mem(rcx)) + + + jmp(.DDONE_BLOCK4) // jump to end. 
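+
+	/* Store path for Block 4 of the (0x0)_L kernel: the first add(rdi, rcx)
+	   skips row 3 (nothing stored there), vmovlpd writes the single element
+	   c(4,4), and the xmm store of ymm6's low lane writes the pair
+	   c(5,4..5), i.e. only the elements on or below the diagonal.  The
+	   beta == 0 path below repeats the same stores without reading c. */
+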
+ + label(.DBETAZERO_BLOCK4) + + + add(rdi, rcx) + + vmovlpd(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + + + label(.DDONE_BLOCK4) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 6 and n_offset is 0(6x0) +(6x0)_L + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 0 1 2 3 4 5 6 7 + +ā†‘ 6 x x x x x x x - +| 7 x x x x x x x x +m 8 x x x x x x x x +off 9 x x x x x x x x +24 10 x x x x x x x x +| 11 x x x x x x x x +ā†“ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_6x0_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + 
prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. 
+ + label(.DPOSTACCUM_BLOCK1) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK1) // jump to end. + + label(.DBETAZERO_BLOCK1) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK1) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. 
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK2) // jump to end. + + label(.DBETAZERO_BLOCK2) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK2) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. 
+ + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm4, mem(rcx)) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK3) // jump to end. 
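+
+	/* Store path for Block 3 of the (6x0)_L kernel: row 6 keeps only
+	   columns 4..6 (column 7 is above the diagonal), so the low xmm of
+	   ymm4 is stored as a pair and vmovlpd writes the third element from
+	   the extracted upper lane; rows 7 and 8 are stored full width.  The
+	   beta == 0 path below mirrors the same pattern without reading c. */
+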
+ + label(.DBETAZERO_BLOCK3) + + + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm4, mem(rcx)) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK3) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. 
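+
+	/* Block 4 of the (6x0)_L kernel covers rows 9..11 against columns
+	   4..7, all of which are stored per the diagram above, so the
+	   post-accumulation and store code below is the plain, unmasked form. */
+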
+ + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK4) // jump to end. 
+ + label(.DBETAZERO_BLOCK4) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK4) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 6 and n_offset is 8(6x8) +(6x8)_L + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 8 9 10 11 12 13 14 15 + +ā†‘ 6 - - - - - - - - +| 7 - - - - - - - - +m 8 x - - - - - - - +off 9 x x - - - - - - +24 10 x x x - - - - - +| 11 x x x x - - - - +ā†“ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_6x8_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) 
// prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. 
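+
+	/* In Block 1 of the (6x8)_L kernel only c(8,8) lies on or below the
+	   diagonal within this 4-column panel (see the diagram above the
+	   function), so the loops above load just the a + 2*rs_a row and
+	   maintain only the third-row accumulators ymm6/ymm9/ymm12/ymm15; the
+	   store later skips rows 6 and 7 and writes that single element. */
+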
+ + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK1) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + + + add(rdi, rcx) + add(rdi, rcx) + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovlpd(xmm6, mem(rcx)) + + + jmp(.DDONE_BLOCK1) // jump to end. + + label(.DBETAZERO_BLOCK1) + + add(rdi, rcx) + add(rdi, rcx) + vmovlpd(xmm6, mem(rcx)) + + + + label(.DDONE_BLOCK1) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. 
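+
+	// From this block onward the unrolled load/FMA bodies are expressed
+	// through the SUBITER_K4_3x4()/SUBITER_K1_3x4() helper macros, which
+	// are presumed to be defined earlier in this file and to perform one
+	// 4-wide (K4) or scalar (K1) update step of the 3x4 accumulator tile,
+	// reading A through the first register argument and B through the
+	// second.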
+ + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. 
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vextractf128(imm(1), ymm5, xmm1 ) + vmovupd(xmm5, mem(rcx)) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK2) // jump to end. + + label(.DBETAZERO_BLOCK2) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vextractf128(imm(1), ymm5, xmm1 ) + vmovupd(xmm5, mem(rcx)) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK2) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 12 and n_offset is 8(12x8) +(12x8)_L + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 8 9 10 11 12 13 14 15 + +ā†‘ 12 x x x x x - - - +| 13 x x x x x x - - +m 14 x x x x x x x - +off 15 x x x x x x x x +24 16 x x x x x x x x +| 17 x x x x x x x x +ā†“ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_12x8_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + 
mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. 
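+
+	// Each ymm accumulator now holds four partial products of a single
+	// (row of A) . (column of B) dot product.  The vhaddpd / vextractf128 /
+	// vaddpd / vperm2f128 sequence in .DPOSTACCUM below folds four such
+	// accumulators into one register holding the four finished dot
+	// products of one row of the 3x4 tile; roughly (illustrative only):
+	//
+	//   row[ j ] = acc[ j ][ 0 ] + acc[ j ][ 1 ] + acc[ j ][ 2 ] + acc[ j ][ 3 ];  // j = 0 .. 3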
+ + label(.DPOSTACCUM_BLOCK1) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK1) // jump to end. + + label(.DBETAZERO_BLOCK1) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + label(.DDONE_BLOCK1) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. 
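+
+	// Accumulator layout for the loops below (repeated at .DPOSTACCUM):
+	// ymm4-ymm15 form a 3 (rows of A) x 4 (columns of B) grid of 4-wide
+	// partial sums,
+	//
+	//   ymm4  ymm7  ymm10  ymm13
+	//   ymm5  ymm8  ymm11  ymm14
+	//   ymm6  ymm9  ymm12  ymm15
+	//
+	// which the post-accumulation code reduces to a 3x4 tile of scalars.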
+ + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. 
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + jmp(.DDONE_BLOCK2) // jump to end. + + label(.DBETAZERO_BLOCK2) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + label(.DDONE_BLOCK2) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. 
+ + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovlpd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vextractf128(imm(1), ymm6, xmm1 ) + vmovupd(xmm6, mem(rcx)) + vmovlpd(xmm1, mem(rcx, 2*8)) + + jmp(.DDONE_BLOCK3) // jump to end. 
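+
+	// As in the beta != 0 path above, the beta == 0 stores below write
+	// only the on/below-diagonal part of each row of this tile: one
+	// element of the first row (vmovlpd), two of the second (vmovupd of
+	// an xmm), and three of the third (vmovupd of an xmm plus vmovlpd of
+	// the extracted high lane), so the strictly upper part of C is never
+	// touched, as gemmt requires.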
+ + label(.DBETAZERO_BLOCK3) + + + vmovlpd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm5, mem(rcx)) + add(rdi, rcx) + + vextractf128(imm(1), ymm6, xmm1 ) + vmovupd(xmm6, mem(rcx)) + vmovlpd(xmm1, mem(rcx, 2*8)) + + label(.DDONE_BLOCK3) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. 
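+
+	// Re the NOTE in the scalar edge loop above: vmovsd leaves the unused
+	// lanes of its xmm destination at zero, and keeping ymm destinations
+	// on the FMAs means those upper lanes merely accumulate 0*0, so the
+	// partial sums already held there survive.  A VEX-encoded FMA with an
+	// xmm destination would instead zero bits 128-255 of the accumulator
+	// and discard them.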
+ + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK4) // jump to end. 
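+
+	// When beta == 0 the vucomisd test above branches here so that C is
+	// never read: the update degenerates from C = beta*C + alpha*(A*B) to
+	// C = alpha*(A*B), which both avoids the loads and keeps any
+	// uninitialized contents of C (e.g. NaNs) from contaminating the
+	// result.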
+ + label(.DBETAZERO_BLOCK4) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK4) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 12 and n_offset is 16(12x16) +(12x16)_L + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 16 17 18 19 20 21 22 23 + +ā†‘ 12 - - - - - - - - +| 13 - - - - - - - - +m 14 - - - - - - - - +off 15 - - - - - - - - +24 16 x - - - - - - - +| 17 x x - - - - - - +ā†“ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_12x16_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + 
vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. 
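+
+	// Note that the main loop above loads A only from rax + 1*rs_a and
+	// rax + 2*rs_a (ymm1/ymm2) and accumulates only into the second and
+	// third rows of the tile (ymm5/ymm6, ymm8/ymm9, ymm11/ymm12,
+	// ymm14/ymm15): the first row of this micro-tile lies entirely in the
+	// uncomputed region of the (12x16)_L block, so its dot products are
+	// skipped.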
+ + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. 
+ je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovlpd(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(xmm6, mem(rcx)) + + + jmp(.DDONE_BLOCK2) // jump to end. + + label(.DBETAZERO_BLOCK2) + + + add(rdi, rcx) + + vmovlpd(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + + label(.DDONE_BLOCK2) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 12, n_offset is 16(12x16) and m_offset is 18, n_offset is 16 (18x16) +(16x12)+(16x18)_L + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 16 17 18 19 20 21 22 23 + +ā†‘ 12 - - - - - - - - +| 13 - - - - - - - - +m 14 - - - - - - - - +off 15 - - - - - - - - +24 16 x - - - - - - - +| 17 x x - - - - - - +ā†“ +ā†‘ 18 x x x - - - - - +| 19 x x x x - - - - +m 20 x x x x x - - - +off 21 x x x x x x - - +24 22 x x x x x x x - +| 23 x x x x x x x x +ā†“ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_16x12_combined_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + lea(mem(r12, rdi, 2), r12) // + 
lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) 
+ vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. 
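+
+	// This combined kernel is organised as two back-to-back asm blocks:
+	// the current one computes the small on/below-diagonal sliver of the
+	// (12x16) region, and after end_asm() the C-level statements
+	// a += 6*rs_a and c += 6*rs_c move down six rows so that a second asm
+	// block can compute the full (18x16) region.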
+ + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovlpd(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(xmm6, mem(rcx)) + + jmp(.DDONE_BLOCK2) // jump to end. + + label(.DBETAZERO_BLOCK2) + + + add(rdi, rcx) + + vmovlpd(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + + label(.DDONE_BLOCK2) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + a += 6 * rs_a; + c += 6 * rs_c; + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + 
vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + 
vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. 
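	// The three loops above (unroll-by-16 main loop, unroll-by-4 edge loop, scalar tail)
	// together cover the whole k dimension: k0 = 16*k_iter16 + 4*k_iter4 + k_left1.
	// Because this is an rd ("row-dot") kernel, each of the twelve accumulators
	// ymm4-ymm15 collects four partial products of one (A row, B column) pair per
	// 4-wide step; the lanes are summed only after the loops. A rough scalar sketch of
	// what one 3x4 block accumulates -- for orientation only; acc[][] is an illustrative
	// name, and the 4*8-byte pointer bumps in the assembly assume cs_a = rs_b = 1:

	for ( dim_t p = 0; p < k0; p++ )         // 16*k_iter16 + 4*k_iter4 + k_left1 steps
		for ( dim_t i = 0; i < 3; i++ )      // the three rows of A owned by this block
			for ( dim_t j = 0; j < 4; j++ )  // the four columns of B owned by this block
				acc[ i ][ j ] += a_ii[ i*rs_a + p*cs_a ] * b_jj[ p*rs_b + j*cs_b ];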
+ + label(.DPOSTACCUM_BLOCK1) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm4, mem(rcx)) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + jmp(.DDONE_BLOCK1) // jump to end. + + label(.DBETAZERO_BLOCK1) + + + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm4, mem(rcx)) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + label(.DDONE_BLOCK1) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. 
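	// The vhaddpd / vextractf128 / vaddpd / vperm2f128 sequence above folds four 4-lane
	// accumulators (one output row) into a single register holding the four finished dot
	// products. An equivalent AVX intrinsics helper, shown only as a sketch -- reduce_row
	// is not a function in this file:

	#include <immintrin.h>

	// Each acc_k holds four partial products of c(i, j+k). Returns
	// { sum(acc0), sum(acc1), sum(acc2), sum(acc3) }, matching the xmm4/xmm5/xmm6
	// comments in the assembly above.
	static inline __m256d reduce_row( __m256d acc0, __m256d acc1,
	                                  __m256d acc2, __m256d acc3 )
	{
		__m256d h01 = _mm256_hadd_pd( acc0, acc1 );                   // pairwise sums, interleaved
		__m128d s01 = _mm_add_pd( _mm256_castpd256_pd128( h01 ),
		                          _mm256_extractf128_pd( h01, 1 ) );  // { sum(acc0), sum(acc1) }
		__m256d h23 = _mm256_hadd_pd( acc2, acc3 );
		__m128d s23 = _mm_add_pd( _mm256_castpd256_pd128( h23 ),
		                          _mm256_extractf128_pd( h23, 1 ) );  // { sum(acc2), sum(acc3) }
		return _mm256_set_m128d( s23, s01 );  // same concatenation as vperm2f128 with imm 0x20
	}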
+ + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. 
+ + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. 
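	// Prefetch pattern used by every block: the three C rows that will be written are
	// touched once up front (3*8 bytes into each row), and each 4-wide k step prefetches
	// the A rows three to five row-strides ahead of the three currently being read.
	// Roughly, in terms of the GCC/Clang builtin (locality hint 3 mirrors the prefetcht0
	// that the macro emits; the double-unit offsets are only a paraphrase of the byte
	// offsets above):

	__builtin_prefetch( c_iijj + 0*rs_c + 3, 0, 3 );  // c + 0*rs_c
	__builtin_prefetch( c_iijj + 1*rs_c + 3, 0, 3 );  // c + 1*rs_c
	__builtin_prefetch( c_iijj + 2*rs_c + 3, 0, 3 );  // c + 2*rs_c
	__builtin_prefetch( a_ii   + 3*rs_a,     0, 3 );  // rows of A ahead of the current three
	__builtin_prefetch( a_ii   + 4*rs_a,     0, 3 );
	__builtin_prefetch( a_ii   + 5*rs_a,     0, 3 );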
+ + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + jmp(.DDONE_BLOCK2) // jump to end. + + label(.DBETAZERO_BLOCK2) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + label(.DDONE_BLOCK2) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. 
+ je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
+ + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm6, ymm6) + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + add(rdi, rcx) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovlpd(xmm6, mem(rcx)) + + jmp(.DDONE_BLOCK3) // jump to end. 
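	// Block 3 sits at the diagonal corner of its column group: only the third row of the
	// 3x4 micro-tile reaches the lower triangle, so the ymm0/ymm1 loads go unused, only
	// ymm6/ymm9/ymm12/ymm15 are accumulated, and after the reduction a single double is
	// written at c + 2*rs_c (the two add(rdi, rcx) above step down two rows first).
	// Effective update as a scalar sketch -- alpha_v, beta_v and acc are illustrative names:

	double t = alpha_v * acc[ 2 ][ 0 ];           // only c(+2 rows, +0 cols) is on/below the diagonal
	if ( beta_v != 0.0 )
		t += beta_v * c_iijj[ 2*rs_c + 0*cs_c ];
	c_iijj[ 2*rs_c + 0*cs_c ] = t;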
+ + label(.DBETAZERO_BLOCK3) + + + add(rdi, rcx) + add(rdi, rcx) + vmovlpd(xmm6, mem(rcx)) + + label(.DDONE_BLOCK3) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + 
vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. 
+ + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(xmm5, mem(rcx)) + vextractf128(imm(1), ymm5, xmm1 ) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + jmp(.DDONE_BLOCK4) // jump to end. 
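	// Every block ends with the same epilogue: scale the reduced sums by alpha, then test
	// beta against zero (the vxorpd / vucomisd / je trio) so that C is never loaded when
	// beta == 0. For a full 3x4 block this is, in scalar form -- alpha_v, beta_v and acc
	// are illustrative names; blocks that touch the diagonal, like Block 4 just above,
	// store only the leading 2, 3 or 4 columns of each row via the xmm/vmovlpd variants:

	double alpha_v = *alpha, beta_v = *beta;
	for ( dim_t i = 0; i < 3; i++ )
		for ( dim_t j = 0; j < 4; j++ )
		{
			double t = alpha_v * acc[ i ][ j ];
			if ( beta_v != 0.0 )                          // beta == 0: skip reading C entirely
				t += beta_v * c_iijj[ i*rs_c + j*cs_c ];
			c_iijj[ i*rs_c + j*cs_c ] = t;
		}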
+ + label(.DBETAZERO_BLOCK4) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm5, mem(rcx)) + vextractf128(imm(1), ymm5, xmm1 ) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + label(.DDONE_BLOCK4) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 18 and n_offset is 16(18x16) +(18x16)_L + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 16 17 18 19 20 21 22 23 + +ā†‘ 18 x x x - - - - - +| 19 x x x x - - - - +m 20 x x x x x - - - +off 21 x x x x x x - - +24 22 x x x x x x x - +| 23 x x x x x x x x +ā†“ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_18x16_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + 
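	// The diagram in the comment above pins down which entries of this 6x8 tile the
	// 18x16 lower-variant kernel owns. A self-contained index check -- the 18/16 offsets
	// come from that comment, everything else here is illustrative:

	for ( dim_t i = 0; i < 6; i++ )                // global row    18 + i
		for ( dim_t j = 0; j < 8; j++ )            // global column 16 + j
		{
			int computed = ( 18 + i >= 16 + j );   // lower variant: keep cells on/below the diagonal
			// computed is 1 for exactly the 'x' cells: 3 columns in row 18,
			// one more per row, up to all 8 columns in row 23.
		}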
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + 
vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. 
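	// The NOTE in the scalar tail above is about lane preservation: vmovsd fills lane 0
	// and zeroes lanes 1-3 of the loaded register, so the full-width FMA touches only
	// lane 0 of each accumulator while the other three partial sums are left intact
	// (a 128-bit xmm-destination FMA would instead zero the upper half of the ymm and
	// destroy them). One tail step in intrinsics, as a sketch only -- a_ptr/b_ptr are
	// illustrative, acc stands for any of ymm4-ymm15, and FMA must be enabled:

	#include <immintrin.h>

	__m256d a_p = _mm256_set_pd( 0.0, 0.0, 0.0, *a_ptr );  // like vmovsd: value in lane 0 only
	__m256d b_p = _mm256_set_pd( 0.0, 0.0, 0.0, *b_ptr );
	acc = _mm256_fmadd_pd( a_p, b_p, acc );                 // lanes 1-3 add 0.0*0.0 and stay put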
+ + label(.DPOSTACCUM_BLOCK1) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm4, mem(rcx)) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + jmp(.DDONE_BLOCK1) // jump to end. + + label(.DBETAZERO_BLOCK1) + + + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm4, mem(rcx)) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + label(.DDONE_BLOCK1) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. 
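	// The pointer bookkeeping between blocks -- the paired lea's adding 3*rs_c / 3*rs_a,
	// the dec of the ii counter in r9 (the ii loop is fully unrolled here), and jj += 4
	// before Block 3 -- walks the 6x8 tile as a 2x2 grid of 3x4 blocks. Outline in C, for
	// orientation only; c_ij/a_i/b_j are illustrative names, and the kernel hard-codes
	// cs_c = 1 in the imul(imm(1*8), rsi) address math:

	for ( dim_t jj = 0; jj < 8; jj += 4 )            // column groups: Blocks 1-2, then Blocks 3-4
		for ( dim_t ii = 0; ii < m_iter; ii++ )      // m_iter = m0 / 3 row groups
		{
			double*       c_ij = c + ii*3*rs_c + jj*cs_c;
			const double* a_i  = a + ii*3*rs_a;
			const double* b_j  = b + jj*cs_b;
			// ... accumulate and store one 3x4 block (see the sketches earlier) ...
		}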
+ + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. 
+ + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. 
+ + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + jmp(.DDONE_BLOCK2) // jump to end. + + label(.DBETAZERO_BLOCK2) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + label(.DDONE_BLOCK2) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a - + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. 
+ je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
+ + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm6, ymm6) + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + add(rdi, rcx) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovlpd(xmm6, mem(rcx)) + + jmp(.DDONE_BLOCK3) // jump to end. 
+ + label(.DBETAZERO_BLOCK3) + + + add(rdi, rcx) + add(rdi, rcx) + vmovlpd(xmm6, mem(rcx)) + + label(.DDONE_BLOCK3) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a - mov(var(k_iter16), rsi) // i = k_iter16; test(rsi, rsi) // check i via logical AND. - je(.DCONSIDKITER4) // if i == 0, jump to code that + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that // contains the k_iter4 loop. - - - label(.DLOOPKITER16) // MAIN LOOP - - + + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + // ---------------------------------- iteration 0 -#if 1 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a -#endif vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) @@ -306,7 +12536,6 @@ void bli_dgemmsup_rd_haswell_asm_6x8m vfmadd231pd(ymm1, ymm3, ymm14) vfmadd231pd(ymm2, ymm3, ymm15) - // ---------------------------------- iteration 1 vmovupd(mem(rax ), ymm0) @@ -335,14 +12564,11 @@ void bli_dgemmsup_rd_haswell_asm_6x8m vfmadd231pd(ymm1, ymm3, ymm14) vfmadd231pd(ymm2, ymm3, ymm15) - // ---------------------------------- iteration 2 - -#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a -#endif vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) @@ -370,7 +12596,6 @@ void bli_dgemmsup_rd_haswell_asm_6x8m vfmadd231pd(ymm1, ymm3, ymm14) vfmadd231pd(ymm2, ymm3, ymm15) - // ---------------------------------- iteration 3 vmovupd(mem(rax ), ymm0) @@ -399,38 +12624,28 @@ void bli_dgemmsup_rd_haswell_asm_6x8m vfmadd231pd(ymm1, ymm3, ymm14) vfmadd231pd(ymm2, ymm3, ymm15) - - dec(rsi) // i -= 1; - jne(.DLOOPKITER16) // iterate again if i != 0. - - - - - - - label(.DCONSIDKITER4) - + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK4) + mov(var(k_iter4), rsi) // i = k_iter4; test(rsi, rsi) // check i via logical AND. - je(.DCONSIDKLEFT1) // if i == 0, jump to code that + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that // considers k_left1 loop. // else, we prepare to enter k_iter4 loop. 
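Editor's note: earlier in this hunk, between Block 3 and Block 4, the paired lea instructions step C and A down by three rows (the "c_ii = r12 += 3*rs_c" and "a_ii = r14 += 3*rs_a" comments). In C terms the walk is just the pointer bump below; advance_to_next_3row_block is a hypothetical helper shown only to make the addressing explicit, and the strides here are element strides rather than the byte strides held in rdi/r8.

    #include "blis.h"  /* dim_t */

    /* Illustrative helper: the two-lea idiom adds 2*stride and then 1*stride,
       i.e. a plain advance by three rows of C and A. */
    static inline void advance_to_next_3row_block
         (
           double**       c_ii, dim_t rs_c,
           const double** a_ii, dim_t rs_a
         )
    {
        *c_ii += 3 * rs_c;  /* lea(mem(r12, rdi, 2), r12); lea(mem(r12, rdi, 1), r12) */
        *a_ii += 3 * rs_a;  /* lea(mem(r14, r8,  2), r14); lea(mem(r14, r8,  1), r14) */
    }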
- - - label(.DLOOPKITER4) // EDGE LOOP (ymm) - -#if 1 + + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a -#endif vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) add(imm(4*8), rax) // a += 4*cs_a = 4*8; - + vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) vfmadd231pd(ymm1, ymm3, ymm5) @@ -452,34 +12667,27 @@ void bli_dgemmsup_rd_haswell_asm_6x8m vfmadd231pd(ymm1, ymm3, ymm14) vfmadd231pd(ymm2, ymm3, ymm15) - dec(rsi) // i -= 1; - jne(.DLOOPKITER4) // iterate again if i != 0. - - - + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) - label(.DCONSIDKLEFT1) - mov(var(k_left1), rsi) // i = k_left1; test(rsi, rsi) // check i via logical AND. - je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left1 loop. - - - - label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) // NOTE: We must use ymm registers here bc // using the xmm registers would zero out the // high bits of the destination registers, // which would destory intermediate results. - + vmovsd(mem(rax ), xmm0) vmovsd(mem(rax, r8, 1), xmm1) vmovsd(mem(rax, r8, 2), xmm2) add(imm(1*8), rax) // a += 1*cs_a = 1*8; - + vmovsd(mem(rbx ), xmm3) vfmadd231pd(ymm0, ymm3, ymm4) vfmadd231pd(ymm1, ymm3, ymm5) @@ -501,22 +12709,15 @@ void bli_dgemmsup_rd_haswell_asm_6x8m vfmadd231pd(ymm1, ymm3, ymm14) vfmadd231pd(ymm2, ymm3, ymm15) - dec(rsi) // i -= 1; - jne(.DLOOPKLEFT1) // iterate again if i != 0. - - - - - + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. + label(.DPOSTACCUM_BLOCK4) - label(.DPOSTACCUM) - - // ymm4 ymm7 ymm10 ymm13 + // ymm4 ymm7 ymm10 ymm13 // ymm5 ymm8 ymm11 ymm14 // ymm6 ymm9 ymm12 ymm15 - + vhaddpd( ymm7, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm0 ) @@ -541,7 +12742,6 @@ void bli_dgemmsup_rd_haswell_asm_6x8m // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) - vhaddpd( ymm9, ymm6, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm0 ) @@ -554,107 +12754,54 @@ void bli_dgemmsup_rd_haswell_asm_6x8m // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) - - - - //mov(var(rs_c), rdi) // load rs_c - //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) - mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm0, ymm6, ymm6) - - - - - - - //mov(var(cs_c), rsi) // load cs_c - //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - - - + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. 
- je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + - - label(.DROWSTORED) - - vfmadd231pd(mem(rcx), ymm3, ymm4) - vmovupd(ymm4, mem(rcx)) + vmovupd(xmm4, mem(rcx)) add(rdi, rcx) - + vfmadd231pd(mem(rcx), ymm3, ymm5) - vmovupd(ymm5, mem(rcx)) + vmovupd(xmm5, mem(rcx)) + vextractf128(imm(1), ymm5, xmm1 ) + vmovlpd(xmm1, mem(rcx, 2*8)) add(rdi, rcx) - - vfmadd231pd(mem(rcx), ymm3, ymm6) - vmovupd(ymm6, mem(rcx)) - //add(rdi, rcx) - - - - jmp(.DDONE) // jump to end. - - - - - label(.DBETAZERO) - - - label(.DROWSTORBZ) - - - vmovupd(ymm4, mem(rcx)) - add(rdi, rcx) - - vmovupd(ymm5, mem(rcx)) - add(rdi, rcx) - + vfmadd231pd(mem(rcx), ymm3, ymm6) vmovupd(ymm6, mem(rcx)) - //add(rdi, rcx) - - - - - label(.DDONE) - - - - - lea(mem(r12, rdi, 2), r12) // - lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c - - lea(mem(r14, r8, 2), r14) // - lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a - dec(r9) // ii -= 1; - jne(.DLOOP3X4I) // iterate again if ii != 0. + jmp(.DDONE_BLOCK4) // jump to end. + label(.DBETAZERO_BLOCK4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) - add(imm(4), r15) // jj += 4; - cmp(imm(8), r15) // compare jj to 8 - jl(.DLOOP3X4J) // if jj < 8, jump to beginning - // of jj loop; otherwise, loop ends. + vmovupd(xmm5, mem(rcx)) + vextractf128(imm(1), ymm5, xmm1 ) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + vmovupd(ymm6, mem(rcx)) - label(.DRETURN) + label(.DDONE_BLOCK4) - + vzeroupper() end_asm( : // output operands (none) @@ -686,42 +12833,6 @@ void bli_dgemmsup_rd_haswell_asm_6x8m "memory" ) - consider_edge_cases: - - // Handle edge cases in the m dimension, if they exist. - if ( m_left ) - { - const dim_t nr_cur = 8; - const dim_t i_edge = m0 - ( dim_t )m_left; - - double* restrict cij = c + i_edge*rs_c; - double* restrict bj = b; - double* restrict ai = a + i_edge*rs_a; - - if ( 2 == m_left ) - { - const dim_t mr_cur = 2; - - bli_dgemmsup_rd_haswell_asm_2x8 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 1 == m_left ) - { - const dim_t mr_cur = 1; - - bli_dgemmsup_rd_haswell_asm_1x8 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - } - } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -840,7 +12951,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c #endif - lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + lea(mem(r8, r8, 4), rcx) // rcx = 5*rs_a @@ -859,7 +12970,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m #if 1 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a - prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + prefetch(0, mem(rax, rcx, 1, 0*8)) // prefetch rax + 5*rs_a #endif vmovupd(mem(rax ), ymm0) @@ -923,7 +13034,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m #if 1 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a - prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + prefetch(0, mem(rax, rcx, 1, 0*8)) // prefetch rax + 5*rs_a #endif vmovupd(mem(rax ), ymm0) @@ -1005,7 +13116,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m #if 1 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a prefetch(0, mem(rax, r8, 4, 0*8)) // 
prefetch rax + 4*rs_a - prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + prefetch(0, mem(rax, rcx, 1, 0*8)) // prefetch rax + 5*rs_a #endif vmovupd(mem(rax ), ymm0) @@ -1141,6 +13252,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + lea(mem(r12), rcx) // rcx = c + 3*ii*rs_c; mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate @@ -1228,7 +13340,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m label(.DRETURN) - + vzeroupper() end_asm( : // output operands (none) @@ -1251,7 +13363,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1844,7 +13956,7 @@ void bli_dgemmsup_rd_haswell_asm_6x2m label(.DRETURN) - + vzeroupper() end_asm( : // output operands (none) diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c index c7a95d65f1..8ac3612bdf 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 22, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,20 +40,20 @@ /* rrr: - -------- ------ -------- - -------- ------ -------- - -------- += ------ ... -------- - -------- ------ -------- - -------- ------ : - -------- ------ : + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : rcr: - -------- | | | | -------- - -------- | | | | -------- - -------- += | | | | ... -------- - -------- | | | | -------- - -------- | | | | : - -------- | | | | : + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : Assumptions: - B is row-stored; @@ -69,12 +69,12 @@ cost of the in-register transpose). crr: - | | | | | | | | ------ -------- - | | | | | | | | ------ -------- - | | | | | | | | += ------ ... -------- - | | | | | | | | ------ -------- - | | | | | | | | ------ : - | | | | | | | | ------ : + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : */ // Prototype reference microkernels. @@ -226,15 +226,15 @@ void bli_dgemmsup_rv_haswell_asm_6x8m // ------------------------------------------------------------------------- begin_asm() - + //vzeroall() // zero all xmm/ymm registers. - + mov(var(a), r14) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) - + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a @@ -274,17 +274,17 @@ void bli_dgemmsup_rv_haswell_asm_6x8m // a latency of 1 cycle, while vzeroall // has a latency of 12 cycles. 
vxorpd(ymm4, ymm4, ymm4) - vxorpd(ymm5, ymm5, ymm5) - vxorpd(ymm6, ymm6, ymm6) - vxorpd(ymm7, ymm7, ymm7) - vxorpd(ymm8, ymm8, ymm8) - vxorpd(ymm9, ymm9, ymm9) - vxorpd(ymm10, ymm10, ymm10) - vxorpd(ymm11, ymm11, ymm11) - vxorpd(ymm12, ymm12, ymm12) - vxorpd(ymm13, ymm13, ymm13) - vxorpd(ymm14, ymm14, ymm14) - vxorpd(ymm15, ymm15, ymm15) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) #endif mov(var(b), rbx) // load address of b. @@ -337,19 +337,19 @@ void bli_dgemmsup_rv_haswell_asm_6x8m lea(mem(rdx, r8, 2), rdx) // from next upanel of a. lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; #endif - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 0 @@ -357,7 +357,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m #else prefetch(0, mem(rdx, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; @@ -368,14 +368,14 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -384,7 +384,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + // ---------------------------------- iteration 1 #if 0 @@ -403,14 +403,14 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -418,8 +418,8 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - - + + // ---------------------------------- iteration 2 #if 0 @@ -427,7 +427,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m #else prefetch(0, mem(rdx, r9, 2, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; @@ -438,14 +438,14 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -453,7 +453,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + // ---------------------------------- iteration 3 @@ -474,14 +474,14 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) 
vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -489,50 +489,50 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - - - + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP #if 1 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; - + vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -540,23 +540,23 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - - + + mov(r12, rcx) // reset rcx to current utile of c. mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm0, ymm6, ymm6) @@ -569,24 +569,24 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vmulpd(ymm0, ymm13, ymm13) vmulpd(ymm0, ymm14, ymm14) vmulpd(ymm0, ymm15, ymm15) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -595,60 +595,60 @@ void bli_dgemmsup_rv_haswell_asm_6x8m cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) vmovupd(ymm4, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) vmovupd(ymm6, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) vmovupd(ymm7, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) vmovupd(ymm8, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) vmovupd(ymm9, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) vmovupd(ymm10, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) vmovupd(ymm11, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) vmovupd(ymm12, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) vmovupd(ymm13, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) vmovupd(ymm14, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) vmovupd(ymm15, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -735,51 +735,51 @@ void bli_dgemmsup_rv_haswell_asm_6x8m jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) - + vmovupd(ymm6, mem(rcx, 0*32)) vmovupd(ymm7, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm8, mem(rcx, 0*32)) vmovupd(ymm9, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm10, mem(rcx, 0*32)) vmovupd(ymm11, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm12, mem(rcx, 0*32)) vmovupd(ymm13, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm14, mem(rcx, 0*32)) vmovupd(ymm15, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. 
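Editor's note: the .DROWSTORED path above reads C and blends in beta via vfmadd231pd, while the .DROWSTORBZ path never touches the old C values; the earlier vucomisd against zero selects between them. A scalar model of that per-element update is sketched below, with update_c as an illustrative name only.

    /* Scalar model of the beta handling in the two store paths. */
    static inline void update_c( double* cij, double ab, double alpha, double beta )
    {
        if ( beta == 0.0 ) *cij = alpha * ab;                  /* .D*STORBZ: C never loaded */
        else               *cij = beta * (*cij) + alpha * ab;  /* .D*STORED: vfmadd231pd    */
    }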
@@ -844,9 +844,9 @@ void bli_dgemmsup_rv_haswell_asm_6x8m //lea(mem(rdx, rsi, 4), rdx) - - - + + + label(.DDONE) @@ -867,8 +867,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m label(.DRETURN) - - + vzeroupper() end_asm( : // output operands (none) @@ -985,34 +984,115 @@ void bli_dgemmsup_rv_haswell_asm_6x8m AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } -void bli_dgemmsup_rv_haswell_asm_6x6m +/* +24x24 block + + 1 1 1 1 1 1 1 1 1 1 2 2 2 2 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 + |- - - - - - - -|- - - - - - - -| - - - - - - - -| +0 | | | | +1 | m_off_24 = 0 | | | +2 | n_off_24 = 0 | | | +3 | m_idx = 0 | | | +4 | n_idx = 0 | | | +5 |- - - - - - - -|- - - - - - - -|- - - - - - - - | +6 | | | | +7 | m_off_24 = 6 | m_off_24 = 6 | | +8 | n_off_24 = 0 | n_off_24 = 8 | | +9 | m_idx = 1 | m_idx = 1 | | +10 | n_idx = 0 | n_idx = 1 | | +11 |- - - - - - - -|- - - - - - - -|- - - - - - - - | +12 | | | | +13 | | m_off_24 = 12 | m_off_24 = 12 | +14 | | n_off_24 = 8 | n_off_24 = 16 | +15 | | m_idx = 2 | m_idx = 2 | +16 | | n_idx = 1 | n_idx = 2 | +17 |- - - - - - - -|- - - - - - - -|- - - - - - - - | +18 | | | | +19 | | | m_off_24 = 18 | +20 | | | n_off_24 = 16 | +21 | | | m_idx = 3 | +22 | | | n_idx = 2 | +23 |- - - - - - - -|- - - - - - - -|- - - - - - - - | +*/ + +#define PREFETCH_C() \ +\ + cmp(imm(8), rdi) \ + jz(.DCOLPFETCH) \ + label(.DROWPFETCH) \ + \ + lea(mem(r12, rdi, 2), rdx) \ + lea(mem(rdx, rdi, 1), rdx) \ + prefetch(0, mem(r12, 7*8)) \ + prefetch(0, mem(r12, rdi, 1, 7*8)) \ + prefetch(0, mem(r12, rdi, 2, 7*8)) \ + prefetch(0, mem(rdx, 7*8)) \ + prefetch(0, mem(rdx, rdi, 1, 7*8)) \ + prefetch(0, mem(rdx, rdi, 2, 7*8)) \ + \ + jmp(.DPOSTPFETCH) \ + label(.DCOLPFETCH) \ + \ + mov(var(cs_c), rsi) \ + lea(mem(, rsi, 8), rsi) \ + lea(mem(r12, rsi, 2), rdx) \ + lea(mem(rdx, rsi, 1), rdx) \ + prefetch(0, mem(r12, 5*8)) \ + prefetch(0, mem(r12, rsi, 1, 5*8)) \ + prefetch(0, mem(r12, rsi, 2, 5*8)) \ + prefetch(0, mem(rdx, 5*8)) \ + prefetch(0, mem(rdx, rsi, 1, 5*8)) \ + prefetch(0, mem(rdx, rsi, 2, 5*8)) \ + lea(mem(rdx, rsi, 2), rdx) \ + prefetch(0, mem(rdx, rsi, 1, 5*8)) \ + prefetch(0, mem(rdx, rsi, 2, 5*8)) \ + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 0 and n_offset is 0(0x0) +(0x0)_L + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 0 1 2 3 4 5 6 7 + +ā†‘ 0 x - - - - - - - +| 1 x x - - - - - - +m 2 x x x - - - - - +off 3 x x x x - - - - +24 4 x x x x x - - - +| 5 x x x x x x - - +ā†“ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_L ( conj_t conja, conj_t conjb, dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); - //void* a_next = bli_auxinfo_next_a( data ); - //void* b_next = bli_auxinfo_next_b( data ); // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. 
uint64_t k_iter = k0 / 4; uint64_t k_left = k0 % 4; - uint64_t m_iter = m0 / 6; - uint64_t m_left = m0 % 6; - uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; uint64_t rs_b = rs_b0; @@ -1021,31 +1101,24 @@ void bli_dgemmsup_rv_haswell_asm_6x6m uint64_t cs_c = cs_c0; // Query the panel stride of A and convert it to units of bytes. - uint64_t ps_a = bli_auxinfo_ps_a( data ); - uint64_t ps_a8 = ps_a * sizeof( double ); + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); - if ( m_iter == 0 ) goto consider_edge_cases; // ------------------------------------------------------------------------- begin_asm() - - //vzeroall() // zero all xmm/ymm registers. - + mov(var(a), r14) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) - + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a - //mov(var(b), rbx) // load address of b. mov(var(rs_b), r10) // load rs_b - //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) - //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last @@ -1057,204 +1130,7235 @@ void bli_dgemmsup_rv_haswell_asm_6x6m mov(var(rs_c), rdi) // load rs_c lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) - - // During preamble and loops: - // r12 = rcx = c - // r14 = rax = a - // read rbx from var(b) near beginning of loop - // r11 = m dim index ii - - mov(var(m_iter), r11) // ii = m_iter; - - label(.DLOOP6X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] - - - -#if 0 - vzeroall() // zero all xmm/ymm registers. -#else + //for triangular kernels we can skip 1st loop around micro kernel // skylake can execute 3 vxorpd ipc with // a latency of 1 cycle, while vzeroall // has a latency of 12 cycles. - vxorpd(ymm1, ymm1, ymm1) // zero ymm1 since we only use the lower - vxorpd(ymm4, ymm4, ymm4) // half (xmm1), and nans/infs may slow us - vxorpd(ymm5, ymm5, ymm5) // down. - vxorpd(ymm6, ymm6, ymm6) - vxorpd(ymm7, ymm7, ymm7) - vxorpd(ymm8, ymm8, ymm8) - vxorpd(ymm9, ymm9, ymm9) - vxorpd(ymm10, ymm10, ymm10) - vxorpd(ymm11, ymm11, ymm11) - vxorpd(ymm12, ymm12, ymm12) - vxorpd(ymm13, ymm13, ymm13) - vxorpd(ymm14, ymm14, ymm14) - vxorpd(ymm15, ymm15, ymm15) -#endif + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) mov(var(b), rbx) // load address of b. - //mov(r12, rcx) // reset rcx to current utile of c. mov(r14, rax) // reset rax to current upanel of a. - - - cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
- jz(.DCOLPFETCH) // jump to column storage case - label(.DROWPFETCH) // row-stored prefetching on c - - lea(mem(r12, rdi, 2), rdx) // - lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c - prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c - prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c - prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c - prefetch(0, mem(rdx, rdi, 2, 5*8)) // prefetch c + 5*rs_c - - jmp(.DPOSTPFETCH) // jump to end of prefetching c - label(.DCOLPFETCH) // column-stored prefetching c - - mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) - lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) - lea(mem(r12, rsi, 2), rdx) // - lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c - prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c - prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c - prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c - prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + PREFETCH_C() + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; label(.DPOSTPFETCH) // done prefetching c -#if 1 mov(var(ps_a8), rdx) // load ps_a8 lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; // use rcx, rdx for prefetching lines // from next upanel of a. -#else - lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines - lea(mem(rdx, r8, 2), rdx) // from next upanel of a. - lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; -#endif - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
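Editor's note: PREFETCH_C() above, and the .DROWSTORED/.DCOLSTORED branches later in this kernel, all key off the same test: rdi holds rs_c already scaled to bytes, so cmp(imm(8), rdi) asks whether rs_c == 1, i.e. whether C has unit row stride and is effectively column-stored. A minimal restatement in C, with c_is_column_stored as an illustrative name:

    #include <stdbool.h>
    #include <stddef.h>

    /* Illustrative restatement of cmp(imm(8), rdi) / jz(.DCOL...). */
    static inline bool c_is_column_stored( size_t rs_c )
    {
        return rs_c * sizeof( double ) == 8;  /* rdi == 8  <=>  rs_c == 1 */
    }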
- - + + // skip computation of ymm5, ymm7, ymm9, ymm11 and compute only half of ymm4, ymm6, ymm13, ymm15 label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 -#if 0 - prefetch(0, mem(rdx, 5*8)) -#else prefetch(0, mem(rdx, 5*8)) -#endif - + vmovupd(mem(rbx, 0*32), ymm0) - vmovupd(mem(rbx, 1*32), xmm1) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) - + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) - vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) - vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(xmm1, xmm2, xmm13) vfmadd231pd(ymm0, ymm3, ymm14) - vfmadd231pd(ymm1, ymm3, ymm15) + vfmadd231pd(xmm1, xmm3, xmm15) + - // ---------------------------------- iteration 1 -#if 0 - prefetch(0, mem(rdx, 5*8)) -#else prefetch(0, mem(rdx, r9, 1, 5*8)) -#endif vmovupd(mem(rbx, 0*32), ymm0) - vmovupd(mem(rbx, 1*32), xmm1) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) - + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) - vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) - vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(xmm1, xmm2, xmm13) vfmadd231pd(ymm0, ymm3, ymm14) - vfmadd231pd(ymm1, ymm3, ymm15) - - + vfmadd231pd(xmm1, xmm3, xmm15) + + // ---------------------------------- iteration 2 -#if 0 - prefetch(0, mem(rdx, 5*8)) -#else prefetch(0, mem(rdx, r9, 2, 5*8)) -#endif - + vmovupd(mem(rbx, 0*32), ymm0) - vmovupd(mem(rbx, 1*32), xmm1) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) - + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) - vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) - vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(xmm1, xmm2, xmm13) vfmadd231pd(ymm0, ymm3, ymm14) - vfmadd231pd(ymm1, ymm3, ymm15) - + vfmadd231pd(xmm1, xmm3, xmm15) + + + // ---------------------------------- iteration 3 + + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, 
xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(xmm1, xmm2, xmm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(xmm1, xmm3, xmm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(xmm1, xmm2, xmm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(xmm1, xmm3, xmm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) + vmovlpd(xmm4, mem(rcx, 0*32)) // write back only lower half of xmm (8 bytes) + + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) + vmovupd(xmm6, mem(rcx, 0*32)) // write only lower half of ymm6 to c + + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(xmm8, mem(rcx, 0*32)) // write lower half of ymm (16 bytes) + vextractf128(imm(1), ymm8, xmm1) // move upper half of ymm to xmm + vmovlpd(xmm1, mem(rcx, 2*8)) // write only lower half of xmm (8 bytes) to rcx + 16 + + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm13) + vmovlpd(xmm13, mem(rcx, 1*32)) // write back only xmm13[0] + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm15) + vmovupd(xmm15, mem(rcx, 1*32)) // write xmm to c (16 bytes) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vextractf128(imm(1), ymm6, xmm1) // move upper half of ymm to xmm1 (ymm6[2], ymm6[3]) + vmovhpd(xmm6, mem(rcx, rsi, 1, 1*8)) // write upper half of xmm6(ymm6[1]) to c + rsi + 8 + vmovupd(xmm1, mem(rcx, rsi, 1, 2*8)) // write xmm1 (ymm6[2], ymm6[3]) to c + rsi + 16 + vextractf128(imm(1), ymm8, xmm1) // move upper half of ymm8 to xmm1 + vmovupd(xmm1, mem(rcx, rsi, 2, 2*8)) // write upper half of ymm8 to c + rsi*2 + 16 + vextractf128(imm(1), ymm10, xmm1) // move uppper half of ymm10 to xmm1 + vmovhpd(xmm1, mem(rcx, rax, 1, 3*8)) // move ymm8[3] to c + rsi*3 + 3*8 + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vmovupd(xmm0, mem(rdx )) // move the first half of ymm13 to c + vmovhpd(xmm1, mem(rdx, rsi, 1, 1*8)) // move the last 8 bits of ymm13 + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovlpd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm8, mem(rcx, 0*32)) + vextractf128(imm(1), ymm8, xmm1) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovlpd(xmm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx, 0*32)) + vmovupd(xmm15, mem(rcx, 1*32)) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vmovupd(ymm4, mem(rcx )) + vextractf128(imm(1), ymm6, xmm1) // move upper half of ymm to xmm1 (ymm6[2], ymm6[3]) + vmovhpd(xmm6, mem(rcx, rsi, 1, 1*8)) // write upper half of xmm6(ymm6[1]) to c + rsi + 8 + vmovupd(xmm1, mem(rcx, rsi, 1, 2*8)) // write xmm1 (ymm6[2], ymm6[3]) to c + rsi + 16 + vextractf128(imm(1), ymm8, xmm1) // move upper half of ymm8 to xmm1 + vmovupd(xmm1, mem(rcx, rsi, 2, 2*8)) // write upper half of ymm8 to c + rsi*2 + 16 + vextractf128(imm(1), ymm10, xmm1) // move uppper half of ymm10 to xmm1 + vmovhpd(xmm1, mem(rcx, rax, 1, 3*8)) // move ymm8[3] to c + rsi*3 + 3*8 + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + + vmovupd(xmm0, mem(rdx )) // move the first half of ymm13 to c + vmovhpd(xmm1, mem(rdx, rsi, 1, 1*8)) // move the last 8 bits of ymm13 + + + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 6 and n_offset is 8(6x8) +(6x8)_L + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 8 9 10 11 12 13 14 15 + +ā†‘ 6 - - - - - - - - +| 7 - - - - - - - - +m 8 x - - - - - - - +off 9 x x - - - - - - +24 10 x x x - - - - - +| 11 x x x x - - - - +ā†“ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict 
beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + + // ------------------------------------------------------------------------- + + begin_asm() + mov(var(a), r14) + mov(var(c), r12) + + mov(var(rs_a), r8) + mov(var(cs_a), r9) + mov(var(rs_b), r10) + mov(var(rs_c), rdi) + lea(mem(, r8, 8), r8) + lea(mem(, r9, 8), r9) + lea(mem(, r10, 8), r10) + lea(mem(, rdi, 8), rdi) + + lea(mem(r8, r8, 2), r13) //3*r8 + lea(mem(r8, r8, 4), r15) //5*r8 + + vxorpd(ymm8, ymm8, ymm8) + vmovapd( ymm8, ymm10) + vmovapd( ymm8, ymm12) + vmovapd( ymm8, ymm14) + mov(var(b), rbx) // load address of b. + mov(r14, rax) + + cmp(imm(8), rdi) + jz(.DCOLPFETCH) + + label(.DROWPFETCH) + lea(mem(r12, rdi, 2), rdx) + lea(mem(rdx, rdi, 1), rdx) + prefetch(0, mem(r12, rdi, 2, 1*8)) + prefetch(0, mem(rdx, 2*8)) + prefetch(0, mem(rdx, rdi, 1, 3*8)) + prefetch(0, mem(rdx, rdi, 2, 4*8)) + jmp(.DPOSTPFETCH) + + label(.DCOLPFETCH) + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + lea(mem(r12, rsi, 2), rdx) + lea(mem(rdx, rsi, 1), rdx) + prefetch(0, mem(r12, 5*8)) + prefetch(0, mem(r12, rsi, 1, 5*8)) + prefetch(0, mem(r12, rsi, 2, 5*8)) + prefetch(0, mem(rdx, 5*8)) + + label(.DPOSTPFETCH) + mov(var(k_iter), rsi) + test(rsi, rsi) + je(.DCONSIDKLEFT) + + // computer xmm8, xmm10, ymm12, ymm14 only + label(.DLOOPKITER) + //0 + vmovupd(mem(rbx, 0*32), ymm0) + vbroadcastsd(mem(rax, r8, 2), ymm1) + vbroadcastsd(mem(rax, r13, 1), ymm2) + vbroadcastsd(mem(rax, r8, 4), ymm3) + vbroadcastsd(mem(rax, r15, 1), ymm4) + vfmadd231pd(xmm0, xmm1, xmm8) + vfmadd231pd(xmm0, xmm2, xmm10) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm0, ymm4, ymm14) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + //1 + vmovupd(mem(rbx, 0*32), ymm0) + vbroadcastsd(mem(rax, r8, 2), ymm1) + vbroadcastsd(mem(rax, r13, 1), ymm2) + vbroadcastsd(mem(rax, r8, 4), ymm3) + vbroadcastsd(mem(rax, r15, 1), ymm4) + vfmadd231pd(xmm0, xmm1, xmm8) + vfmadd231pd(xmm0, xmm2, xmm10) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm0, ymm4, ymm14) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + //2 + vmovupd(mem(rbx, 0*32), ymm0) + vbroadcastsd(mem(rax, r8, 2), ymm1) + vbroadcastsd(mem(rax, r13, 1), ymm2) + vbroadcastsd(mem(rax, r8, 4), ymm3) + vbroadcastsd(mem(rax, r15, 1), ymm4) + vfmadd231pd(xmm0, xmm1, xmm8) + vfmadd231pd(xmm0, xmm2, xmm10) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm0, ymm4, ymm14) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + //3 + vmovupd(mem(rbx, 0*32), ymm0) + vbroadcastsd(mem(rax, r8, 2), ymm1) + vbroadcastsd(mem(rax, r13, 1), ymm2) + vbroadcastsd(mem(rax, r8, 4), ymm3) + vbroadcastsd(mem(rax, r15, 1), ymm4) + vfmadd231pd(xmm0, xmm1, xmm8) + vfmadd231pd(xmm0, xmm2, xmm10) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm0, ymm4, ymm14) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. 
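Editor's note: as the diagrams for these gemmt kernels describe, each specialized 6x8 block updates only the part of C on or below the diagonal of the enclosing 24x24 tile, which is why only xmm8, xmm10, ymm12 and ymm14 are accumulated in the loop above and why the stores that follow write 1, 2, 3 and 4 doubles per row. A scalar model of that masking is sketched below; m_off_24 and n_off_24 follow the naming in the comments, gemmt_lower_tile_ref and the rest are illustrative names, and strides are in elements.

    #include "blis.h"  /* dim_t, inc_t */

    /* Illustrative scalar model of a lower-variant gemmt micro-tile: only
       elements whose global row index is >= the global column index are formed. */
    static void gemmt_lower_tile_ref
         (
           dim_t m_off_24, dim_t n_off_24, dim_t k0,
           double alpha, const double* a, inc_t rs_a, inc_t cs_a,
                         const double* b, inc_t rs_b, inc_t cs_b,
           double beta,        double* c, inc_t rs_c, inc_t cs_c
         )
    {
        for ( dim_t i = 0; i < 6; ++i )
            for ( dim_t j = 0; j < 8; ++j )
            {
                if ( m_off_24 + i < n_off_24 + j ) continue;  /* strictly upper: skipped */

                double ab = 0.0;
                for ( dim_t p = 0; p < k0; ++p )
                    ab += a[ i*rs_a + p*cs_a ] * b[ p*rs_b + j*cs_b ];

                double* cij = &c[ i*rs_c + j*cs_c ];
                *cij = ( beta == 0.0 ) ? alpha * ab : beta * (*cij) + alpha * ab;
            }
    }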
+ + + label(.DCONSIDKLEFT) + mov(var(k_left), rsi) + test(rsi, rsi) + je(.DPOSTACCUM) + + label(.DLOOPKLEFT) + vmovupd(mem(rbx, 0*32), ymm0) + vbroadcastsd(mem(rax, r8, 2), ymm1) + vbroadcastsd(mem(rax, r13, 1), ymm2) + vbroadcastsd(mem(rax, r8, 4), ymm3) + vbroadcastsd(mem(rax, r15, 1), ymm4) + vfmadd231pd(xmm0, xmm1, xmm8) + vfmadd231pd(xmm0, xmm2, xmm10) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm0, ymm4, ymm14) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + label(.DPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm14, ymm14) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + label(.DROWSTORED) + lea(mem(rcx , rdi, 2), rcx) + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm8) + vmovlpd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm10) + vmovupd(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(xmm12, mem(rcx, 0*32)) + vextractf128(imm(1), ymm12, xmm1) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + + vextractf128(imm(1), ymm4, xmm1) + vmovupd(xmm1, mem(rcx, 2*8 )) // write upper half of ymm4 to c + vextractf128(imm(1), ymm6, xmm1) + vmovhpd(xmm1, mem(rcx, rsi, 1, 3*8)) // write last element of ymm6 + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovhpd(xmm4, mem(rdx, rax, 1, 1*8)) // write only last 8 bytes of second half of ymm14 + + lea(mem(rdx, rsi, 4), rdx) + + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORBZ) // jump to column storage case + + label(.DROWSTORBZ) + lea(mem(rcx , rdi, 2), rcx) + + vmovlpd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm12, mem(rcx, 0*32)) + vextractf128(imm(1), ymm12, xmm1) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx, 0*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vbroadcastsd(mem(rbx), ymm3) + + vextractf128(imm(1), ymm4, xmm1) + vmovupd(xmm1, mem(rcx, 2*8 )) // write upper half of ymm4 to c + vextractf128(imm(1), ymm6, xmm1) + vmovhpd(xmm1, mem(rcx, rsi, 1, 3*8)) // write last element of ymm6 + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovhpd(xmm4, mem(rdx, rax, 1, 1*8)) // write only last 8 bytes of second half of ymm14 + + + label(.DDONE) + vzeroupper() + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 12 and n_offset is 16(12x16) +(12x16)_L + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 16 17 18 19 20 21 22 23 + +ā†‘ 12 - - - - - - - - +| 13 - - - - - - - - +m 14 - - - - - - - - +off 15 - - - - - - - - +24 16 x - - - - - - - +| 17 x x - - - - - - +ā†“ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + begin_asm() + mov(var(a), r14) + mov(var(b), rbx) + mov(var(c), r12) + mov(r14, rax) + + mov(var(rs_a), r8) + mov(var(cs_a), r9) + lea(mem(, r8, 8), r8) + lea(mem(, r9, 8), r9) + + mov(var(rs_b), r10) + lea(mem(, r10, 8), r10) + + mov(var(rs_c), rdi) + lea(mem(, rdi, 8), rdi) + + lea(mem(r8, r8, 4), r15) + + vxorpd(ymm12, ymm12, ymm12) + vmovapd(ymm12, ymm14) + + cmp(imm(8), rdi) + jz(.DCOLPFETCH) + + label(.DROWPFETCH) + lea(mem(r12, rdi, 2), rdx) + lea(mem(rdx, rdi, 1), rdx) + 
prefetch(0, mem(rdx, rdi, 1, 1*8)) + prefetch(0, mem(rdx, rdi, 2, 2*8)) + jmp(.DPOSTPFETCH) + + label(.DCOLPFETCH) + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + prefetch(0, mem(r12, 5*8)) + prefetch(0, mem(r12, rsi, 1, 5*8)) + + label(.DPOSTPFETCH) + lea(mem(rax, r8, 4), rax) + mov(var(k_iter), rsi) + test(rsi, rsi) + je(.DCONSILEFT) + + //compute xmm12 and xmm 14 + label(.DMAIN) + //0 + vmovupd(mem(rbx, 0*32), xmm0) + vbroadcastsd(mem(rax, r8, 4), ymm3) + vbroadcastsd(mem(rax, r15, 1), ymm4) + vfmadd231pd(xmm0, xmm3, xmm12) + vfmadd231pd(xmm0, xmm4, xmm14) + add(r10, rbx) + add(r9, rax) + //1 + vmovupd(mem(rbx, 0*32), xmm0) + vbroadcastsd(mem(rax, r8, 4), ymm3) + vbroadcastsd(mem(rax, r15, 1), ymm4) + vfmadd231pd(xmm0, xmm3, xmm12) + vfmadd231pd(xmm0, xmm4, xmm14) + add(r10, rbx) + add(r9, rax) + //2 + vmovupd(mem(rbx, 0*32), xmm0) + vbroadcastsd(mem(rax, r8, 4), ymm3) + vbroadcastsd(mem(rax, r15, 1), ymm4) + vfmadd231pd(xmm0, xmm3, xmm12) + vfmadd231pd(xmm0, xmm4, xmm14) + add(r10, rbx) + add(r9, rax) + //3 + vmovupd(mem(rbx, 0*32), xmm0) + vbroadcastsd(mem(rax, r8, 4), ymm3) + vbroadcastsd(mem(rax, r15, 1), ymm4) + vfmadd231pd(xmm0, xmm3, xmm12) + vfmadd231pd(xmm0, xmm4, xmm14) + add(r10, rbx) + add(r9, rax) + + dec(rsi) + jne(.DMAIN) + + label(.DCONSILEFT) + mov(var(k_left), rsi) + test(rsi, rsi) + je(.DPOSTACC) + + label(.DLEFT) + vmovupd(mem(rbx, 0*32), xmm0) + vbroadcastsd(mem(rax, r8, 4), ymm3) + vbroadcastsd(mem(rax, r15, 1), ymm4) + vfmadd231pd(xmm0, xmm3, xmm12) + vfmadd231pd(xmm0, xmm4, xmm14) + add(r10, rbx) + add(r9, rax) + dec(rsi) + jne(.DLEFT) + + label(.DPOSTACC) + mov(r12, rcx) + mov(var(alpha), rax) + mov(var(beta), rbx) + vbroadcastsd(mem(rax), ymm0) + vbroadcastsd(mem(rbx), ymm3) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm14, ymm14) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + vxorpd(ymm0, ymm0, ymm0) + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) + + cmp(imm(8), rdi) //rs_c == 0? + je(.DCOLSTOR) + + label(.DROWSTOR) + lea(mem(rcx, rdi, 4), rcx) //rcx += 4 * rdi + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm12) + vmovlpd(xmm12, mem(rcx)) + add(rdi, rcx) + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm14) + vmovlpd(xmm14, mem(rcx)) + vmovhpd(xmm14, mem(rcx, rsi, 1)) + jmp(.DDONE) + + label(.DCOLSTOR) + + lea(mem(rcx, rdi, 4), rdx) + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vinsertf128(imm(0x1), xmm2, ymm0, ymm12) + vinsertf128(imm(0x1), xmm3, ymm1, ymm14) + + vfmadd231pd(mem(rdx), xmm3, xmm12) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm14) + vmovupd(xmm12, mem(rdx )) + vmovhpd(xmm14, mem(rdx, rsi, 1, 1*8)) + jmp(.DDONE) + + label(.DBETAZERO) + cmp(imm(8), rdi) //rs_c == 0? 
+ je(.DCOLSTORBZ) + + label(.DROWSTORBZ) + lea(mem(rcx, rdi, 4), rcx) //rcx += 4 * rdi + vmovlpd(xmm12, mem(rcx)) + add(rdi, rcx) + vmovlpd(xmm14, mem(rcx)) + vmovhpd(xmm14, mem(rcx, rsi, 1)) + jmp(.DDONE) + + label(.DCOLSTORBZ) + + lea(mem(rcx, rdi, 4), rdx) + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vinsertf128(imm(0x1), xmm2, ymm0, ymm12) + vinsertf128(imm(0x1), xmm3, ymm1, ymm14) + + vmovupd(xmm12, mem(rdx )) + vmovhpd(xmm14, mem(rdx, rsi, 1, 1*8)) + jmp(.DDONE) + + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 12, n_offset is 16(12x16) and m_offset is 18, n_offset is 16 (18x16) +(16x12)+(16x18)_L + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 16 17 18 19 20 21 22 23 + +ā†‘ 12 - - - - - - - - +| 13 - - - - - - - - +m 14 - - - - - - - - +off 15 - - - - - - - - +24 16 x - - - - - - - +| 17 x x - - - - - - +ā†“ +ā†‘ 18 x x x - - - - - +| 19 x x x x - - - - +m 20 x x x x x - - - +off 21 x x x x x x - - +24 22 x x x x x x x - +| 23 x x x x x x x x +ā†“ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) + { + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + double* a_next = ( (double*)a ) + rs_a * 6; + begin_asm() + mov(var(a), r14) + mov(var(b), rbx) + mov(var(c), r12) + mov(var(a_next), r11) + mov(r14, rax) + + mov(var(rs_a), r8) + mov(var(cs_a), r9) + lea(mem(, r8, 8), r8) + lea(mem(, r9, 8), r9) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // 5 + mov(var(rs_b), r10) + lea(mem(, r10, 8), r10) + + mov(var(rs_c), rdi) + lea(mem(, rdi, 8), rdi) + + lea(mem(r8, r8, 4), r15) + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + cmp(imm(8), rdi) + jz(.DCOLPFETCH) + + label(.DROWPFETCH) + lea(mem(r12, rdi, 2), rdx) + lea(mem(rdx, rdi, 1), rdx) + + prefetch(0, mem(rdx, rdi, 1, 1*8)) // c + 4 * rs_c + prefetch(0, mem(rdx, rdi, 2, 2*8)) + lea(mem(rdx, rdi, 2), rdx) + lea(mem(rdx, rdi, 1), rdx) // c + 6 *rsc + prefetch(0, 
mem(rdx, 7*8)) + prefetch(0, mem(rdx, rdi, 1, 7*8)) + prefetch(0, mem(rdx, rdi, 2, 7*8)) + lea(mem(rdx, rdi, 2), rdx) + lea(mem(rdx, rdi, 1), rdx) + prefetch(0, mem(rdx, 7*8)) + prefetch(0, mem(rdx, rdi, 1, 7*8)) + prefetch(0, mem(rdx, rdi, 2, 7*8)) + + jmp(.DPOSTPFETCH) + + label(.DCOLPFETCH) + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 11*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 11*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 11*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 11*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 11*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 11*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 11*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 11*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) + mov(var(ps_a8), rdx) + lea(mem(rax, rdx, 1), rdx) //rdx = a + ps_a8 //for prefetch + mov(var(ps_a8), rbp) + lea(mem(r11, rbp, 1), rbp) //rdx = a + ps_a8 //for prefetch + mov(var(k_iter), rsi) + test(rsi, rsi) + je(.DCONSILEFT) + + // ymm5 and ymm7 contains the data for 16x12 block, other registers contains data for 16x18 block + label(.DMAIN) + //0 + prefetch(0, mem(rdx, 5*8)) + prefetch(0, mem(rbp, 5*8)) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm5) + vfmadd231pd(xmm0, xmm3, xmm7) + vbroadcastsd(mem(r11 ), ymm2) + vbroadcastsd(mem(r11, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(r11, r8, 2), ymm2) + vbroadcastsd(mem(r11, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(r11, r8, 4), ymm2) + vbroadcastsd(mem(r11, r15, 1), ymm3) + add(r9, r11) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + add(r10, rbx) + add(r9, rax) + //1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + prefetch(0, mem(rbp, r9, 1, 5*8)) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm5) + vfmadd231pd(xmm0, xmm3, xmm7) + vbroadcastsd(mem(r11 ), ymm2) + vbroadcastsd(mem(r11, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(r11, r8, 2), ymm2) + vbroadcastsd(mem(r11, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(r11, r8, 4), ymm2) + vbroadcastsd(mem(r11, r15, 1), ymm3) + add(r9, r11) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + add(r10, rbx) + add(r9, rax) + //2 + prefetch(0, mem(rdx, r9, 2, 5*8)) + prefetch(0, mem(rbp, r9, 2, 5*8)) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm5) + vfmadd231pd(xmm0, xmm3, xmm7) + vbroadcastsd(mem(r11 ), ymm2) + vbroadcastsd(mem(r11, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + 
vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(r11, r8, 2), ymm2) + vbroadcastsd(mem(r11, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(r11, r8, 4), ymm2) + vbroadcastsd(mem(r11, r15, 1), ymm3) + add(r9, r11) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + add(r10, rbx) + add(r9, rax) + //3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + prefetch(0, mem(rbp, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) + lea(mem(rbp, r9, 4), rbp) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm5) + vfmadd231pd(xmm0, xmm3, xmm7) + vbroadcastsd(mem(r11 ), ymm2) + vbroadcastsd(mem(r11, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(r11, r8, 2), ymm2) + vbroadcastsd(mem(r11, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(r11, r8, 4), ymm2) + vbroadcastsd(mem(r11, r15, 1), ymm3) + add(r9, r11) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + add(r10, rbx) + add(r9, rax) + + dec(rsi) + jne(.DMAIN) + + label(.DCONSILEFT) + mov(var(k_left), rsi) + test(rsi, rsi) + je(.DPOSTACC) + + label(.DLEFT) + prefetch(0, mem(rdx, 5*8)) + prefetch(0, mem(rbp, 5*8)) + add(r9, rbp) + add(r9, rdx) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm5) + vfmadd231pd(xmm0, xmm3, xmm7) + vbroadcastsd(mem(r11 ), ymm2) + vbroadcastsd(mem(r11, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(r11, r8, 2), ymm2) + vbroadcastsd(mem(r11, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(r11, r8, 4), ymm2) + vbroadcastsd(mem(r11, r15, 1), ymm3) + add(r9, r11) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + add(r10, rbx) + add(r9, rax) + + dec(rsi) + jne(.DLEFT) + + label(.DPOSTACC) + mov(r12, rcx) + mov(var(alpha), rax) + mov(var(beta), rbx) + vbroadcastsd(mem(rax), ymm0) + vbroadcastsd(mem(rbx), ymm3) + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm4, ymm4) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + vxorpd(ymm0, ymm0, ymm0) + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) + + cmp(imm(8), rdi) //rs_c == 8? 
+ je(.DCOLSTOR) + + label(.DROWSTOR) + lea(mem(rcx, rdi, 4), rcx) //rcx += 4 * rdi + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm5) + vmovlpd(xmm5, mem(rcx)) + add(rdi, rcx) + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm7) + vmovlpd(xmm7, mem(rcx)) + vmovhpd(xmm7, mem(rcx, rsi, 1)) + + //for lower 6x8 + lea(mem(rcx, rdi, 1), rcx) //rcx += 1 * rdi + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(xmm4, mem(rcx, 0*32)) + vextractf128(imm(1), ymm4, xmm1) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovlpd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(xmm13, mem(rcx, 1*32)) + vextractf128(imm(1), ymm13, xmm1) + vmovlpd(xmm1, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) + + label(.DCOLSTOR) + vbroadcastsd(mem(rbx), ymm3) + + lea(mem(rcx, rdi, 4), rdx) //rdx = rcx + 4* rs_c + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vfmadd231pd(mem(rdx), xmm3, xmm5) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm7) + vmovupd(xmm5, mem(rdx )) + vmovhpd(xmm7, mem(rdx, rsi, 1, 1*8)) + + lea(mem(rcx, rdi, 4), rcx) + lea(mem(rcx, rdi, 2), rcx) + lea(mem(rcx, rdi, 4), rdx) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vextractf128(imm(1), ymm10, xmm1) + vmovhpd(xmm10, mem(rcx, rax, 1, 1*8)) + vmovupd(xmm1, mem(rcx, rax, 1, 2*8)) + + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vextractf128(imm(1), ymm5, xmm1) + vmovupd(xmm1, mem(rcx, 2*8 )) + vextractf128(imm(1), ymm7, xmm1) + vmovhpd(xmm1, mem(rcx, rsi, 1, 3*8)) + + 
vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovhpd(xmm4, mem(rdx, rax, 1, 1*8)) + jmp(.DDONE) + + label(.DBETAZERO) + cmp(imm(8), rdi) + je(.DCOLSTORBZ) + + label(.DROWSTORBZ) + lea(mem(rcx, rdi, 4), rcx) //rcx += 4 * rdi + vmovlpd(xmm5, mem(rcx)) + add(rdi, rcx) + vmovlpd(xmm7, mem(rcx)) + vmovhpd(xmm7, mem(rcx, rsi, 1)) + + //For lower 6x8 block + lea(mem(rcx, rdi, 1), rcx) //rcx += 1 * rdi + vmovupd(xmm4, mem(rcx, 0*32)) + vextractf128(imm(1), ymm4, xmm1) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx, 0*32)) + + vmovlpd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm10, mem(rcx, 0*32)) + + vmovupd(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + + vmovupd(xmm13, mem(rcx, 1*32)) + vextractf128(imm(1), ymm13, xmm1) + vmovlpd(xmm1, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + vmovupd(ymm14, mem(rcx, 0*32)) + + vmovupd(ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) + + label(.DCOLSTORBZ) + vbroadcastsd(mem(rbx), ymm3) + + lea(mem(rcx, rdi, 4), rdx) //rdx = rcx + 4* rs_c + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vmovupd(xmm5, mem(rdx )) + vmovhpd(xmm7, mem(rdx, rsi, 1, 1*8)) + + lea(mem(rcx, rdi, 4), rcx) + lea(mem(rcx, rdi, 2), rcx) + lea(mem(rcx, rdi, 4), rdx) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vextractf128(imm(1), ymm10, xmm1) + vmovhpd(xmm10, mem(rcx, rax, 1, 1*8)) + vmovupd(xmm1, mem(rcx, rax, 1, 2*8)) + + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vbroadcastsd(mem(rbx), ymm3) + + vextractf128(imm(1), ymm5, xmm1) + vmovupd(xmm1, mem(rcx, 2*8 )) + vextractf128(imm(1), ymm7, xmm1) + vmovhpd(xmm1, mem(rcx, rsi, 1, 3*8)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovhpd(xmm4, mem(rdx, rax, 1, 1*8)) + jmp(.DDONE) + + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [a_next] "m" (a_next), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + 
[ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + } +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 6 and n_offset is 0(6x0) +(6x0)_L + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 0 1 2 3 4 5 6 7 + +ā†‘ 6 x x x x x x x - +| 7 x x x x x x x x +m 8 x x x x x x x x +off 9 x x x x x x x x +24 10 x x x x x x x x +| 11 x x x x x x x x +ā†“ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + + + // ------------------------------------------------------------------------- + + begin_asm() + + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + + + PREFETCH_C() + + label(.DPOSTPFETCH) // done prefetching c + + + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
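/* Illustrative sketch (not part of the kernel): .DLOOPKITER and .DLOOPKLEFT
   below implement a 4x-unrolled k loop. k_iter = k0/4 full passes run in the
   main loop and k_left = k0%4 single steps run in the edge loop, matching the
   k_iter/k_left split computed at the top of this function. The helper name
   rank1_update() is hypothetical and stands for one broadcast/FMA step. */
#include <stdint.h>

extern void rank1_update( void );   /* hypothetical stand-in for one k step */

static void k_loop_structure( uint64_t k0 )
{
    uint64_t k_iter = k0 / 4;
    uint64_t k_left = k0 % 4;

    for ( uint64_t i = 0; i < k_iter; ++i )
    {
        rank1_update();   /* iteration 0 */
        rank1_update();   /* iteration 1 */
        rank1_update();   /* iteration 2 */
        rank1_update();   /* iteration 3 */
    }
    for ( uint64_t i = 0; i < k_left; ++i )
        rank1_update();   /* edge loop: one k step per pass */
}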
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rdx, 5*8)) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + // ---------------------------------- iteration 1 + + prefetch(0, mem(rdx, r9, 1, 5*8)) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + // ---------------------------------- iteration 2 + + prefetch(0, mem(rdx, r9, 2, 5*8)) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + // ---------------------------------- iteration 3 + + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. 
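/* Illustrative scalar reference (assumed names, not part of the kernel): each
   unrolled iteration above loads one row of B into ymm0/ymm1 (8 doubles),
   broadcasts the six A elements of the current column, and issues FMAs into
   the accumulators ymm4..ymm15; row i of the 6x8 tile lives in ymm(4+2*i)
   (columns 0-3) and ymm(5+2*i) (columns 4-7). The whole k loop, main and edge
   parts together, therefore computes the plain inner product below; alpha and
   beta are applied later in .DPOSTACCUM. */
#include <stdint.h>

static void dgemm_6x8_reference( uint64_t k0,
                                 const double* a, uint64_t rs_a, uint64_t cs_a,
                                 const double* b, uint64_t rs_b, uint64_t cs_b,
                                 double ab[6][8] )
{
    for ( uint64_t l = 0; l < k0; ++l )          /* one pass of the k loop    */
        for ( int i = 0; i < 6; ++i )            /* one vbroadcastsd per row  */
            for ( int j = 0; j < 8; ++j )        /* two vfmadd231pd per row   */
                ab[ i ][ j ] += a[ i*rs_a + l*cs_a ] * b[ l*rs_b + j*cs_b ];
}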
+ + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(xmm5, mem(rcx, 1*32)) + vextractf128(imm(1), ymm5, xmm1) + vmovlpd(xmm1, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, 1*32)) + + + jmp(.DDONE) // jump to end. 
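/* Illustrative sketch (assumed names): the row-stored path above writes only
   the part of the 6x8 tile that lies on or below the diagonal of the full
   matrix, per the (6x0)_L diagram for this kernel -- e.g. its first row
   (global row 6) stores seven elements (ymm4, then xmm5, then one lane
   extracted from ymm5) and skips column 7. In scalar form the masked update
   performed by the lower-variant kernels is: */
static void gemmt_lower_tile_update( double* c, long rs_c, long cs_c,
                                     long m_off, long n_off,
                                     double alpha, double beta,
                                     const double ab[6][8] )
{
    for ( int i = 0; i < 6; ++i )
        for ( int j = 0; j < 8; ++j )
            if ( m_off + i >= n_off + j )        /* lower variant: col <= row */
                c[ i*rs_c + j*cs_c ] = alpha * ab[ i ][ j ]
                                     + beta  * c[ i*rs_c + j*cs_c ];
}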
+ + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovhpd(xmm11, mem(rcx, rax, 1, 1*8)) + vextractf128(imm(1), ymm11, xmm1) + vmovupd(xmm1, mem(rcx, rax, 1, 2*8)) + + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(xmm5, mem(rcx, 1*32)) + vextractf128(imm(1), ymm5, xmm1) + vmovlpd(xmm1, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx, 0*32)) + vmovupd(ymm15, mem(rcx, 1*32)) + + + jmp(.DDONE) // jump to end. 
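/* Illustrative intrinsics equivalent (not the kernel's code) of the
   vunpcklpd/vunpckhpd/vinsertf128/vperm2f128 sequence used by .DCOLSTORED
   above and .DCOLSTORBZ below: four row accumulators are transposed in
   registers so that whole columns of C can be stored contiguously when C is
   column-major. r0..r3 stand for ymm4, ymm6, ymm8, ymm10 (the BLIS asm macros
   use AT&T operand order, destination last). */
#include <immintrin.h>

static inline void transpose_4x4_pd( __m256d r0, __m256d r1,
                                      __m256d r2, __m256d r3,
                                      __m256d col[4] )
{
    __m256d t0 = _mm256_unpacklo_pd( r0, r1 );   /* c00 c10 c02 c12 */
    __m256d t1 = _mm256_unpackhi_pd( r0, r1 );   /* c01 c11 c03 c13 */
    __m256d t2 = _mm256_unpacklo_pd( r2, r3 );   /* c20 c30 c22 c32 */
    __m256d t3 = _mm256_unpackhi_pd( r2, r3 );   /* c21 c31 c23 c33 */

    col[0] = _mm256_insertf128_pd( t0, _mm256_castpd256_pd128( t2 ), 1 );
    col[1] = _mm256_insertf128_pd( t1, _mm256_castpd256_pd128( t3 ), 1 );
    col[2] = _mm256_permute2f128_pd( t0, t2, 0x31 );
    col[3] = _mm256_permute2f128_pd( t1, t3, 0x31 );
}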
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovhpd(xmm11, mem(rcx, rax, 1, 1*8)) + vextractf128(imm(1), ymm11, xmm1) + vmovupd(xmm1, mem(rcx, rax, 1, 2*8)) + + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + + // Handle edge cases in the m dimension, if they exist. 
+ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 12 and n_offset is 8(12x8) +(12x8)_L + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 8 9 10 11 12 13 14 15 + +ā†‘ 12 x x x x x - - - +| 13 x x x x x x - - +m 14 x x x x x x x - +off 15 x x x x x x x x +24 16 x x x x x x x x +| 17 x x x x x x x x +ā†“ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + + + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + + + PREFETCH_C() + + label(.DPOSTPFETCH) // done prefetching c + + + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
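/* Illustrative sketch (assumed variable names): the ps_a8 load just above sets
   rdx to a + ps_a8, i.e. the start of the *next* micro-panel of A, and the
   prefetch(0, mem(rdx, ...)) instructions inside the loop below walk that
   pointer forward by one column of A per k step, so the next panel is already
   in cache when this one is exhausted. Roughly: */
#include <stdint.h>
#include <xmmintrin.h>   /* _mm_prefetch, _MM_HINT_T0 */

static inline void prefetch_next_a_panel( const double* a, uint64_t ps_a8,
                                          uint64_t cs_a, uint64_t k0 )
{
    const char* ap = (const char*)a + ps_a8;                /* rdx = a + ps_a8 */
    for ( uint64_t l = 0; l < k0; ++l )
        _mm_prefetch( ap + ( l*cs_a + 5 )*sizeof(double),   /* mem(rdx, 5*8)   */
                      _MM_HINT_T0 );
}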
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rdx, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + // ---------------------------------- iteration 1 + + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + // ---------------------------------- iteration 2 + + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + // ---------------------------------- iteration 3 + + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. 
+ + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovlpd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(xmm9, mem(rcx, 1*32)) + vextractf128(imm(1), ymm9, xmm1) + vmovlpd(xmm1, mem(rcx, 1*32+2*8)) + + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, 1*32)) + + + jmp(.DDONE) // jump to end. 
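/* Sketch of the storage-layout dispatch used earlier in this function (helper
   names are assumed): rdi holds rs_c*sizeof(double), so the cmp(imm(8), rdi) /
   jz(.DCOLSTORED) pair selects the column-storage path exactly when rs_c == 1,
   i.e. when C is column-major; any other row stride falls through to the
   row-oriented stores above. */
#include <stdint.h>

typedef void (*store_path_fn)( void );   /* hypothetical store routines */

static inline store_path_fn select_store_path( uint64_t rs_c,
                                               store_path_fn colstored,  /* .DCOLSTORED */
                                               store_path_fn rowstored ) /* .DROWSTORED */
{
    return ( rs_c * sizeof( double ) == 8 ) ? colstored : rowstored;
}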
+ + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vextractf128(imm(1), ymm7, xmm1) + vmovhpd(xmm7, mem(rcx, rsi, 1, 1*8)) + vmovupd(xmm1, mem(rcx, rsi, 1, 2*8)) + vextractf128(imm(1), ymm9, xmm1) + vmovupd(xmm1, mem(rcx, rsi, 2, 2*8)) + vextractf128(imm(1), ymm11, xmm1) + vmovhpd(xmm1, mem(rcx, rax, 1, 3*8)) + + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovlpd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(xmm9, mem(rcx, 1*32)) + vextractf128(imm(1), ymm9, xmm1) + vmovlpd(xmm1, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx, 0*32)) + vmovupd(ymm15, mem(rcx, 1*32)) + + + jmp(.DDONE) // jump to end. 
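/* Sketch of the beta handling (assumed names): the vucomisd/je(.DBETAZERO)
   test earlier routes beta == 0 to the *BZ store paths (.DROWSTORBZ above,
   .DCOLSTORBZ below), which never read C. This matters because with beta == 0
   the output may be uninitialized, and loading it for a fused
   beta*C + alpha*AB update could propagate NaNs or Infs. Per element: */
static inline void store_elem( double* cij, double alpha, double beta, double ab )
{
    if ( beta != 0.0 )
        *cij = alpha * ab + beta * (*cij);   /* .DROWSTORED / .DCOLSTORED */
    else
        *cij = alpha * ab;                   /* .DROWSTORBZ / .DCOLSTORBZ */
}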
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx )) + vextractf128(imm(1), ymm7, xmm1) + vmovhpd(xmm7, mem(rcx, rsi, 1, 1*8)) + vmovupd(xmm1, mem(rcx, rsi, 1, 2*8)) + vextractf128(imm(1), ymm9, xmm1) + vmovupd(xmm1, mem(rcx, rsi, 2, 2*8)) + vextractf128(imm(1), ymm11, xmm1) + vmovhpd(xmm1, mem(rcx, rax, 1, 3*8)) + + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + + // Handle edge cases in the m dimension, if they exist. 
+ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 18 and n_offset is 16(18x16) +(18x16)_L + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 16 17 18 19 20 21 22 23 + +ā†‘ 18 x x x - - - - - +| 19 x x x x - - - - +m 20 x x x x x - - - +off 21 x x x x x x - - +24 22 x x x x x x x - +| 23 x x x x x x x x +ā†“ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + + + // ------------------------------------------------------------------------- + + begin_asm() + + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + + + PREFETCH_C() + + label(.DPOSTPFETCH) // done prefetching c + + + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
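/* Illustrative note with a scalar sketch (assumed names): for this (18x16)_L
   tile, global rows 18 and 19 only ever touch columns 16-19, so the register
   clearing above skips ymm5 and ymm7 (the column 4-7 accumulators of those
   two rows) and the loop below issues no FMAs into them. In scalar terms the
   accumulation is: */
#include <stdint.h>

static void accum_18x16_L_sketch( uint64_t k0,
                                  const double* a, uint64_t rs_a, uint64_t cs_a,
                                  const double* b, uint64_t rs_b, uint64_t cs_b,
                                  double ab[6][8] )
{
    for ( uint64_t l = 0; l < k0; ++l )
        for ( int i = 0; i < 6; ++i )
        {
            int jmax = ( i < 2 ) ? 4 : 8;   /* rows 18,19: columns 16-19 only */
            for ( int j = 0; j < jmax; ++j )
                ab[ i ][ j ] += a[ i*rs_a + l*cs_a ] * b[ l*rs_b + j*cs_b ];
        }
}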
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rdx, 5*8)) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + // ---------------------------------- iteration 1 + + prefetch(0, mem(rdx, r9, 1, 5*8)) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + // ---------------------------------- iteration 2 + + prefetch(0, mem(rdx, r9, 2, 5*8)) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + // ---------------------------------- iteration 3 + + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
+ + + label(.DLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(xmm4, mem(rcx, 0*32)) + vextractf128(imm(1), ymm4, xmm1) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovlpd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(xmm13, mem(rcx, 1*32)) + vextractf128(imm(1), ymm13, xmm1) + vmovlpd(xmm1, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, 1*32)) + + + jmp(.DDONE) // jump to end. 
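/* Illustrative intrinsics equivalent (not the kernel's code) of the partial
   stores used by the row-stored path above: a full ymm holds four doubles, but
   rows near the diagonal keep fewer, so the kernel composes 1-, 2- and
   3-element stores out of vmovupd(xmm)/vextractf128/vmovlpd/vmovhpd. For
   example the three-element case (row 18 of this tile keeps columns 16-18): */
#include <immintrin.h>

static inline void store3_pd( double* dst, __m256d v )
{
    _mm_storeu_pd( dst,     _mm256_castpd256_pd128( v ) );   /* lanes 0-1 (vmovupd xmm)           */
    _mm_store_sd ( dst + 2, _mm256_extractf128_pd( v, 1 ) ); /* lane 2 (vextractf128 + vmovlpd)   */
}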
+ + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vextractf128(imm(1), ymm10, xmm1) + vmovhpd(xmm10, mem(rcx, rax, 1, 1*8)) + vmovupd(xmm1, mem(rcx, rax, 1, 2*8)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vextractf128(imm(1), ymm5, xmm1) + vmovupd(xmm1, mem(rcx, 2*8 )) + vextractf128(imm(1), ymm7, xmm1) + vmovhpd(xmm1, mem(rcx, rsi, 1, 3*8)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovhpd(xmm4, mem(rdx, rax, 1, 1*8)) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx, 0*32)) + vextractf128(imm(1), ymm4, xmm1) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovlpd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(xmm13, mem(rcx, 1*32)) + vextractf128(imm(1), ymm13, xmm1) + vmovlpd(xmm1, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx, 0*32)) + vmovupd(ymm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + jmp(.DDONE) // jump to end. 
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vextractf128(imm(1), ymm10, xmm1) + vmovhpd(xmm10, mem(rcx, rax, 1, 1*8)) + vmovupd(xmm1, mem(rcx, rax, 1, 2*8)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vbroadcastsd(mem(rbx), ymm3) + + vextractf128(imm(1), ymm5, xmm1) + vmovupd(xmm1, mem(rcx, 2*8 )) + vextractf128(imm(1), ymm7, xmm1) + vmovhpd(xmm1, mem(rcx, rsi, 1, 3*8)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovhpd(xmm4, mem(rdx, rax, 1, 1*8)) + + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + + // Handle edge cases in the m dimension, if they exist. 
+ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 0 and n_offset is 0(0x0) +(0x0)_U + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 0 1 2 3 4 5 6 7 + +ā†‘ 0 x x x x x x x x +| 1 - x x x x x x x +m 2 - - x x x x x x +off 3 - - - x x x x x +24 4 - - - - x x x x +| 5 - - - - - x x x +ā†“ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + + begin_asm() + + + mov(var(a), r14) + mov(var(rs_a), r8) + mov(var(cs_a), r9) + lea(mem(, r8, 8), r8) + lea(mem(, r9, 8), r9) + + lea(mem(r8, r8, 2), r13) + lea(mem(r8, r8, 4), r15) + + mov(var(rs_b), r10) + lea(mem(, r10, 8), r10) + + mov(var(c), r12) + mov(var(rs_c), rdi) + lea(mem(, rdi, 8), rdi) + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm15) + + mov(var(b), rbx) + mov(r14, rax) + + + + PREFETCH_C() + + label(.DPOSTPFETCH) + + mov(var(ps_a8), rdx) + lea(mem(rax, rdx, 1), rdx) + lea(mem(r9, r9, 2), rcx) + + mov(var(k_iter), rsi) + test(rsi, rsi) + je(.DCONSIDKLEFT) + + //ymm12, ymm14 can be skipped + label(.DLOOPKITER) // MAIN LOOP + //0 + prefetch(0, mem(rdx, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + + //1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + //2 + + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + 
vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + //3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + + dec(rsi) + jne(.DLOOPKITER) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) + test(rsi, rsi) + je(.DPOSTACCUM) + + label(.DLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + + dec(rsi) + jne(.DLOOPKLEFT) + + label(.DPOSTACCUM) + + + + mov(r12, rcx) + mov(var(alpha), rax) + mov(var(beta), rbx) + vbroadcastsd(mem(rax), ymm0) + vbroadcastsd(mem(rbx), ymm3) + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + lea(mem(rcx, rdi, 4), rdx) // c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // 3*cs_c; + + + vxorpd(ymm0, ymm0, ymm0) + vucomisd(xmm0, xmm3) + je(.DBETAZERO) + + cmp(imm(8), rdi) + jz(.DCOLSTORED) + + label(.DROWSTORED) + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovhpd(xmm6, mem(rcx, 0*32+1*8)) + vextractf128(imm(0x1), ymm6, xmm6) + vmovupd(xmm6, mem(rcx, 0*32+2*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), 
ymm3, ymm8) + vextractf128(imm(0x1), ymm8, xmm8) + vmovupd(xmm8, mem(rcx, 0*32+2*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vextractf128(imm(0x1), ymm10, xmm10) + vmovhpd(xmm10, mem(rcx, 0*32+3*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovhpd(xmm15, mem(rcx, 1*32+1*8)) + vextractf128(imm(0x1), ymm15, xmm15) + vmovupd(xmm15, mem(rcx, 1*32+2*8)) + + + jmp(.DDONE) + + + label(.DCOLSTORED) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovlpd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vextractf128(imm(0x1), ymm8, xmm8) + vmovlpd(xmm8, mem(rcx, rsi, 2, 1*16)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovlpd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + jmp(.DDONE) // jump to end. 
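+	// beta == 0 case: C is not read; the alpha-scaled results are written with the
+	// same partial stores so that only the upper-triangular entries of this 6x8
+	// block are touched.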
+ + label(.DBETAZERO) // if beta zero + + + cmp(imm(8), rdi) + jz(.DCOLSTORBZ) + + label(.DROWSTORBZ) + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovhpd(xmm6, mem(rcx, 0*32+1*8)) + vextractf128(imm(0x1), ymm6, xmm6) + vmovupd(xmm6, mem(rcx, 0*32+2*8)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + vextractf128(imm(0x1), ymm8, xmm8) + vmovupd(xmm8, mem(rcx, 0*32+2*8)) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + vextractf128(imm(0x1), ymm10, xmm10) + vmovhpd(xmm10, mem(rcx, 0*32+3*8)) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovhpd(xmm15, mem(rcx, 1*32+1*8)) + vextractf128(imm(0x1), ymm15, xmm15) + vmovupd(xmm15, mem(rcx, 1*32+2*8)) + + jmp(.DDONE) + + label(.DCOLSTORBZ) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovlpd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vextractf128(imm(0x1), ymm8, xmm8) + vmovlpd(xmm8, mem(rcx, rsi, 2, 1*16)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovlpd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 6 and n_offset is 8(6x8) +(6x8)_U + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 8 9 10 11 12 13 14 15 + +ā†‘ 6 x x x x x x x x +| 7 x x x x x x x x +m 8 x x x x x x x x +off 9 - x x x x x x x +24 10 - - x x x x x x +| 11 - - - x x x x x +ā†“ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* 
restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + + begin_asm() + + + mov(var(a), r14) + mov(var(rs_a), r8) + mov(var(cs_a), r9) + lea(mem(, r8, 8), r8) + lea(mem(, r9, 8), r9) + + lea(mem(r8, r8, 2), r13) + lea(mem(r8, r8, 4), r15) + + mov(var(rs_b), r10) + lea(mem(, r10, 8), r10) + + mov(var(c), r12) + mov(var(rs_c), rdi) + lea(mem(, rdi, 8), rdi) + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + mov(var(b), rbx) + mov(r14, rax) + + + PREFETCH_C() + + label(.DPOSTPFETCH) + + mov(var(ps_a8), rdx) + lea(mem(rax, rdx, 1), rdx) + lea(mem(r9, r9, 2), rcx) + + mov(var(k_iter), rsi) + test(rsi, rsi) + je(.DCONSIDKLEFT) + + label(.DLOOPKITER) // MAIN LOOP + //0 + prefetch(0, mem(rdx, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + //1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + //2 + + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, 
ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + //3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + dec(rsi) + jne(.DLOOPKITER) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) + test(rsi, rsi) + je(.DPOSTACCUM) + + label(.DLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + dec(rsi) + jne(.DLOOPKLEFT) + + label(.DPOSTACCUM) + + + + mov(r12, rcx) + mov(var(alpha), rax) + mov(var(beta), rbx) + vbroadcastsd(mem(rax), ymm0) + vbroadcastsd(mem(rbx), ymm3) + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + lea(mem(rcx, rdi, 4), rdx) // c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // 3*cs_c; + + + vxorpd(ymm0, ymm0, ymm0) + vucomisd(xmm0, xmm3) + je(.DBETAZERO) + + cmp(imm(8), rdi) + jz(.DCOLSTORED) + + label(.DROWSTORED) + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovhpd(xmm10, mem(rcx, 0*32+1*8)) + vextractf128(imm(0x1), ymm10, xmm10) + vmovupd(xmm10, mem(rcx, 0*32+2*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vextractf128(imm(0x1), ymm12, xmm12) + vmovupd(xmm12, mem(rcx, 
0*32+2*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vextractf128(imm(0x1), ymm14, xmm14) + vmovhpd(xmm14, mem(rcx, 0*32+3*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) + + + label(.DCOLSTORED) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(xmm4, mem(rcx )) + vextractf128(imm(0x1), ymm4, xmm4) + vmovlpd(xmm4, mem(rcx, 2*8 )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovlpd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + jmp(.DDONE) // jump to end. 
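+	// beta == 0 case: store the alpha-scaled results directly; rows 3-5 of the block
+	// use partial stores on the first four columns so only entries on or above the
+	// diagonal are written.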
+ + label(.DBETAZERO) // if beta zero + + + cmp(imm(8), rdi) + jz(.DCOLSTORBZ) + + label(.DROWSTORBZ) + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovhpd(xmm10, mem(rcx, 0*32+1*8)) + vextractf128(imm(0x1), ymm10, xmm10) + vmovupd(xmm10, mem(rcx, 0*32+2*8)) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm12, xmm12) + vmovupd(xmm12, mem(rcx, 0*32+2*8)) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + vextractf128(imm(0x1), ymm14, xmm14) + vmovhpd(xmm14, mem(rcx, 0*32+3*8)) + vmovupd(ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) + + label(.DCOLSTORBZ) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(xmm4, mem(rcx )) + vextractf128(imm(0x1), ymm4, xmm4) + vmovlpd(xmm4, mem(rcx, 2*8 )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovlpd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 12 and n_offset is 16(12x16) +(12x16)_U + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 16 17 18 19 20 21 22 23 + +ā†‘ 12 x x x x x x x x +| 13 x x x x x x x x +m 14 x x x x x x x x +off 15 x x x x x x x x +24 16 x x x x x x x x +| 17 - x x x x x x x +ā†“ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_U + ( + conj_t conja, + 
conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + + begin_asm() + + + mov(var(a), r14) + mov(var(rs_a), r8) + mov(var(cs_a), r9) + lea(mem(, r8, 8), r8) + lea(mem(, r9, 8), r9) + + lea(mem(r8, r8, 2), r13) + lea(mem(r8, r8, 4), r15) + + mov(var(rs_b), r10) + lea(mem(, r10, 8), r10) + + mov(var(c), r12) + mov(var(rs_c), rdi) + lea(mem(, rdi, 8), rdi) + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + mov(var(b), rbx) + mov(r14, rax) + + + PREFETCH_C() + + label(.DPOSTPFETCH) + + mov(var(ps_a8), rdx) + lea(mem(rax, rdx, 1), rdx) + lea(mem(r9, r9, 2), rcx) + + mov(var(k_iter), rsi) + test(rsi, rsi) + je(.DCONSIDKLEFT) + + label(.DLOOPKITER) // MAIN LOOP + //0 + prefetch(0, mem(rdx, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + //1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + //2 + + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, 
ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + //3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + dec(rsi) + jne(.DLOOPKITER) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) + test(rsi, rsi) + je(.DPOSTACCUM) + + label(.DLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + dec(rsi) + jne(.DLOOPKLEFT) + + label(.DPOSTACCUM) + + + + mov(r12, rcx) + mov(var(alpha), rax) + mov(var(beta), rbx) + vbroadcastsd(mem(rax), ymm0) + vbroadcastsd(mem(rbx), ymm3) + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + lea(mem(rcx, rdi, 4), rdx) // c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // 3*cs_c; + + + vxorpd(ymm0, ymm0, ymm0) + vucomisd(xmm0, xmm3) + je(.DBETAZERO) + + cmp(imm(8), rdi) + jz(.DCOLSTORED) + + label(.DROWSTORED) + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + 
add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovhpd(xmm14, mem(rcx, 0*32+1*8)) + vextractf128(imm(0x1), ymm14, xmm14) + vmovupd(xmm14, mem(rcx, 0*32+2*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) + + + label(.DCOLSTORED) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovlpd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + jmp(.DDONE) // jump to end. 
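+	// beta == 0 case: write the alpha-scaled block directly; only the last row needs
+	// a partial store, since just its first element falls below the diagonal.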
+ + label(.DBETAZERO) + + + cmp(imm(8), rdi) + jz(.DCOLSTORBZ) + + label(.DROWSTORBZ) + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovhpd(xmm14, mem(rcx, 0*32+1*8)) + vextractf128(imm(0x1), ymm14, xmm14) + vmovupd(xmm14, mem(rcx, 0*32+2*8)) + vmovupd(ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) + + label(.DCOLSTORBZ) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovlpd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 6 and n_offset is 0(6x0) +(6x0)_U + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 0 1 2 3 4 5 6 7 + +ā†‘ 6 - - - - - - x x +| 7 - - - - - - - x +m 8 - - - - - - - - +off 9 - - - - - - - - +24 10 - - - - - - - - +| 11 - - - - - - - - +ā†“ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t 
cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + + begin_asm() + mov(var(a), r14) + mov(var(b), rbx) + mov(var(c), r12) + mov(r14, rax) + + mov(var(rs_a), r8) + mov(var(cs_a), r9) + lea(mem(, r8, 8), r8) + lea(mem(, r9, 8), r9) + + mov(var(rs_b), r10) + lea(mem(, r10, 8), r10) + + mov(var(rs_c), rdi) + lea(mem(, rdi, 8), rdi) + + lea(mem(r8, r8, 4), r15) + + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + + cmp(imm(8), rdi) + jz(.DCOLPFETCH) + + label(.DROWPFETCH) + lea(mem(r12, rdi, 2), rdx) + lea(mem(rdx, rdi, 1), rdx) + prefetch(0, mem(rdx, rdi, 1, 1*8)) + prefetch(0, mem(rdx, rdi, 2, 2*8)) + jmp(.DPOSTPFETCH) + + label(.DCOLPFETCH) + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + prefetch(0, mem(r12, 5*8)) + prefetch(0, mem(r12, rsi, 1, 5*8)) + + label(.DPOSTPFETCH) + mov(var(k_iter), rsi) + test(rsi, rsi) + lea(mem(rbx, 1*16), rbx) + je(.DCONSILEFT) + + //compute xmm5 and xmm7 only + label(.DMAIN) + //0 + vmovupd(mem(rbx, 1*32), xmm1) + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm1, xmm2, xmm5) + vfmadd231pd(xmm1, xmm3, xmm7) + add(r10, rbx) + add(r9, rax) + //1 + vmovupd(mem(rbx, 1*32), xmm1) + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm1, xmm2, xmm5) + vfmadd231pd(xmm1, xmm3, xmm7) + add(r10, rbx) + add(r9, rax) + //2 + vmovupd(mem(rbx, 1*32), xmm1) + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm1, xmm2, xmm5) + vfmadd231pd(xmm1, xmm3, xmm7) + add(r10, rbx) + add(r9, rax) + //3 + vmovupd(mem(rbx, 1*32), xmm1) + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm1, xmm2, xmm5) + vfmadd231pd(xmm1, xmm3, xmm7) + add(r10, rbx) + add(r9, rax) + + dec(rsi) + jne(.DMAIN) + + label(.DCONSILEFT) + mov(var(k_left), rsi) + test(rsi, rsi) + je(.DPOSTACC) + + label(.DLEFT) + vmovupd(mem(rbx, 1*32), xmm1) + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm1, xmm2, xmm5) + vfmadd231pd(xmm1, xmm3, xmm7) + add(r10, rbx) + add(r9, rax) + dec(rsi) + jne(.DLEFT) + + label(.DPOSTACC) + mov(r12, rcx) + mov(var(alpha), rax) + mov(var(beta), rbx) + vbroadcastsd(mem(rax), ymm0) + vbroadcastsd(mem(rbx), ymm3) + lea(mem(rsi, rsi, 2), rax) + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm7, ymm7) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + vxorpd(ymm0, ymm0, ymm0) + vucomisd(xmm0, xmm3) + je(.DBETAZERO) + + cmp(imm(8), rdi) + je(.DCOLSTOR) + + label(.DROWSTOR) + lea(mem(rcx, 1*32), rcx) + lea(mem(rcx, 1*16), rcx) + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm5) + vmovlpd(xmm5, mem(rcx)) + vmovhpd(xmm5, mem(rcx, rsi, 1)) + add(rdi, rcx) + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm7) + vmovhpd(xmm7, mem(rcx, rsi, 1)) + + jmp(.DDONE) + + label(.DCOLSTOR) + + vbroadcastsd(mem(rbx), ymm3) + lea(mem(rcx, rsi, 4), rcx) + lea(mem(rcx, rsi, 2), rcx) + vunpcklpd(xmm7, xmm5, xmm0) + vunpckhpd(xmm7, xmm5, xmm1) + vfmadd231pd(mem(rcx ), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + jmp(.DDONE) + + label(.DBETAZERO) + cmp(imm(8), rdi) + je(.DCOLSTORBZ) + + 
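+	// beta == 0, general (row-stored) C: store only the three elements of this block
+	// that lie on or above the diagonal, held in xmm5 (row 0) and xmm7 (row 1).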
label(.DROWSTORBZ) + lea(mem(rcx, 1*32), rcx) + lea(mem(rcx, 1*16), rcx) + + vmovlpd(xmm5, mem(rcx)) + vmovhpd(xmm5, mem(rcx, rsi, 1)) + add(rdi, rcx) + vmovhpd(xmm7, mem(rcx, rsi, 1)) + + jmp(.DDONE) + + label(.DCOLSTORBZ) + + lea(mem(rcx, rsi, 4), rcx) + lea(mem(rcx, rsi, 2), rcx) + vunpcklpd(xmm7, xmm5, xmm0) + vunpckhpd(xmm7, xmm5, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + jmp(.DDONE) + + + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 12 and n_offset is 8(12x8) +(12x8)_U + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 8 9 10 11 12 13 14 15 + +ā†‘ 12 - - - - x x x x +| 13 - - - - - x x x +m 14 - - - - - - x x +off 15 - - - - - - - x +24 16 - - - - - - - - +| 17 - - - - - - - - +ā†“ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + + begin_asm() + + + mov(var(a), r14) + mov(var(rs_a), r8) + mov(var(cs_a), r9) + lea(mem(, r8, 8), r8) + lea(mem(, r9, 8), r9) + + lea(mem(r8, r8, 2), r13) + lea(mem(r8, r8, 4), r15) + + mov(var(rs_b), r10) + lea(mem(, r10, 8), r10) + + mov(var(c), r12) + mov(var(rs_c), rdi) + lea(mem(, rdi, 8), rdi) + + vxorpd(ymm5, ymm5, ymm5) + vmovapd( ymm5, ymm7) + vmovapd( ymm5, ymm9) + vmovapd( ymm5, ymm11) + + mov(var(b), rbx) + mov(r14, rax) + + cmp(imm(8), rdi) + jz(.DCOLPFETCH) + label(.DROWPFETCH) + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) + label(.DCOLPFETCH) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + lea(mem(r12, rsi, 2), rdx) + lea(mem(rdx, rsi, 1), rdx) + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 
7*cs_c + + label(.DPOSTPFETCH) + + mov(var(ps_a8), rdx) + lea(mem(rax, rdx, 1), rdx) + lea(mem(r9, r9, 2), rcx) + + mov(var(k_iter), rsi) + test(rsi, rsi) + je(.DCONSIDKLEFT) + + //compute ymm5, 7, 9, 11 only + label(.DLOOPKITER) // MAIN LOOP + //0 + prefetch(0, mem(rdx, 5*8)) + + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) + //1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) + //2 + + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) + //3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) + + dec(rsi) + jne(.DLOOPKITER) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) + test(rsi, rsi) + je(.DPOSTACCUM) + + label(.DLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) + + dec(rsi) + jne(.DLOOPKLEFT) + + label(.DPOSTACCUM) + + + + mov(r12, rcx) + mov(var(alpha), rax) + mov(var(beta), rbx) + vbroadcastsd(mem(rax), ymm0) + vbroadcastsd(mem(rbx), ymm3) + + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm11, ymm11) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + lea(mem(rcx, rdi, 4), rdx) // c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // 3*cs_c; + + + vxorpd(ymm0, ymm0, ymm0) + vucomisd(xmm0, xmm3) + je(.DBETAZERO) + + cmp(imm(8), rdi) + jz(.DCOLSTORED) + + label(.DROWSTORED) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovhpd(xmm7, mem(rcx, 1*32+1*8)) + vextractf128(imm(0x1), ymm7, xmm7) + vmovupd(xmm7, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vextractf128(imm(0x1), ymm9, xmm9) + vmovupd(xmm9, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vextractf128(imm(0x1), 
ymm11, xmm11) + vmovhpd(xmm11, mem(rcx, 1*32+3*8)) + + + jmp(.DDONE) + + + label(.DCOLSTORED) + + lea(mem(rdx, rsi, 4), rdx) + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovlpd(xmm5, mem(rcx )) + vmovupd(xmm7, mem(rcx, rsi, 1)) + vmovupd(xmm9, mem(rcx, rsi, 2)) + vextractf128(imm(0x1), ymm9, xmm9) + vmovlpd(xmm9, mem(rcx, rsi, 2, 1*16)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) + jz(.DCOLSTORBZ) + + label(.DROWSTORBZ) + + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovhpd(xmm7, mem(rcx, 1*32+1*8)) + vextractf128(imm(0x1), ymm7, xmm7) + vmovupd(xmm7, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + vextractf128(imm(0x1), ymm9, xmm9) + vmovupd(xmm9, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + vextractf128(imm(0x1), ymm11, xmm11) + vmovhpd(xmm11, mem(rcx, 1*32+3*8)) + + jmp(.DDONE) + + label(.DCOLSTORBZ) + + lea(mem(rdx, rsi, 4), rdx) + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovlpd(xmm5, mem(rcx )) + vmovupd(xmm7, mem(rcx, rsi, 1)) + vmovupd(xmm9, mem(rcx, rsi, 2)) + vextractf128(imm(0x1), ymm9, xmm9) + vmovlpd(xmm9, mem(rcx, rsi, 2, 2*8)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 18 and n_offset is 16(18x16) +(18x16)_U + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 16 17 18 19 20 21 22 23 + +ā†‘ 18 - - x x x x x x +| 19 - - - x x x x x +m 20 - - - - x x x x +off 21 - - - - - x x x +24 22 - - - - - - x x +| 23 - - - - - - - x +ā†“ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) 
+{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + + begin_asm() + + + mov(var(a), r14) + mov(var(rs_a), r8) + mov(var(cs_a), r9) + lea(mem(, r8, 8), r8) + lea(mem(, r9, 8), r9) + + lea(mem(r8, r8, 2), r13) + lea(mem(r8, r8, 4), r15) + + mov(var(rs_b), r10) + lea(mem(, r10, 8), r10) + + mov(var(c), r12) + mov(var(rs_c), rdi) + lea(mem(, rdi, 8), rdi) + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm15) + + mov(var(b), rbx) + mov(r14, rax) + + + + PREFETCH_C() + + label(.DPOSTPFETCH) + + mov(var(ps_a8), rdx) + lea(mem(rax, rdx, 1), rdx) + lea(mem(r9, r9, 2), rcx) + + mov(var(k_iter), rsi) + test(rsi, rsi) + je(.DCONSIDKLEFT) + + //skip ymm8, 10, 12, 14 + label(.DLOOPKITER) // MAIN LOOP + //0 + prefetch(0, mem(rdx, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + + //1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + + //2 + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + + //3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + 
vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + + dec(rsi) + jne(.DLOOPKITER) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) + test(rsi, rsi) + je(.DPOSTACCUM) + + label(.DLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + + dec(rsi) + jne(.DLOOPKLEFT) + + label(.DPOSTACCUM) + + + + mov(r12, rcx) + mov(var(alpha), rax) + mov(var(beta), rbx) + vbroadcastsd(mem(rax), ymm0) + vbroadcastsd(mem(rbx), ymm3) + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm15, ymm15) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + lea(mem(rcx, rdi, 4), rdx) // c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // 3*cs_c; + + + vxorpd(ymm0, ymm0, ymm0) + vucomisd(xmm0, xmm3) + je(.DBETAZERO) + + cmp(imm(8), rdi) + jz(.DCOLSTORED) + + label(.DROWSTORED) + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vextractf128(imm(0x1), ymm4, xmm4) + vmovupd(xmm4, mem(rcx, 0*32+2*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vextractf128(imm(0x1), ymm6, xmm6) + vmovhpd(xmm6, mem(rcx, 0*32+3*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovhpd(xmm11, mem(rcx, 1*32+1*8)) + vextractf128(imm(0x1), ymm11, xmm11) + vmovupd(xmm11, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vextractf128(imm(0x1), ymm13, xmm13) + vmovupd(xmm13, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vextractf128(imm(0x1), ymm15, xmm15) + vmovhpd(xmm15, mem(rcx, 1*32+3*8)) + //add(rdi, rcx) + + + jmp(.DDONE) + + + label(.DCOLSTORED) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovlpd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + 
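/*
   Transpose note: the vunpcklpd/vunpckhpd + vinsertf128/vperm2f128 sequence
   just above rearranges four row-oriented accumulators (r0..r3, four doubles
   each) into four column vectors before the column-stored update. A sketch of
   the lane movement, with t0..t3 standing for the ymm0-ymm3 temporaries:

     t0 = unpcklpd(r0, r1)            = { r0[0], r1[0], r0[2], r1[2] }
     t1 = unpckhpd(r0, r1)            = { r0[1], r1[1], r0[3], r1[3] }
     t2 = unpcklpd(r2, r3),   t3 = unpckhpd(r2, r3)
     col0 = insertf128(t0.lo, t2.lo)  = { r0[0], r1[0], r2[0], r3[0] }
     col1 = insertf128(t1.lo, t3.lo)  = { r0[1], r1[1], r2[1], r3[1] }
     col2 = perm2f128(t0.hi, t2.hi)   = { r0[2], r1[2], r2[2], r3[2] }
     col3 = perm2f128(t1.hi, t3.hi)   = { r0[3], r1[3], r2[3], r3[3] }
*/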
vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(xmm5, mem(rcx )) + vextractf128(imm(0x1), ymm5, xmm5) + vmovlpd(xmm5, mem(rcx, 2*8 )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovlpd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) + jz(.DCOLSTORBZ) + + label(.DROWSTORBZ) + + vextractf128(imm(0x1), ymm4, xmm4) + vmovupd(xmm4, mem(rcx, 0*32+2*8)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vextractf128(imm(0x1), ymm6, xmm6) + vmovhpd(xmm6, mem(rcx, 0*32+3*8)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovhpd(xmm11, mem(rcx, 1*32+1*8)) + vextractf128(imm(0x1), ymm11, xmm11) + vmovupd(xmm11, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + vextractf128(imm(0x1), ymm13, xmm13) + vmovupd(xmm13, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + vextractf128(imm(0x1), ymm15, xmm15) + vmovhpd(xmm15, mem(rcx, 1*32+3*8)) + + jmp(.DDONE) + + label(.DCOLSTORBZ) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovlpd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(xmm5, mem(rcx )) + vextractf128(imm(0x1), ymm5, xmm5) + vmovlpd(xmm5, mem(rcx, 2*8 )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovlpd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 0, n_offset is 0(0x0) and m_offset is 6, n_offset is 0 (6x0) +(0x0)+(6x0)_L + +the 
region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 0 1 2 3 4 5 6 7 + +ā†‘ 0 x x x x x x x x +| 1 - x x x x x x x +m 2 - - x x x x x x +off 3 - - - x x x x x +24 4 - - - - x x x x +| 5 - - - - - x x x +ā†“ +ā†‘ 6 - - - - - - x x +| 7 - - - - - - - x +m 8 - - - - - - - - +off 9 - - - - - - - - +24 10 - - - - - - - - +| 11 - - - - - - - - +ā†“ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + + begin_asm() + + + mov(var(a), r14) + mov(var(rs_a), r8) + mov(var(cs_a), r9) + lea(mem(, r8, 8), r8) + lea(mem(, r9, 8), r9) + + lea(mem(r8, r8, 2), r13) + lea(mem(r8, r8, 4), r15) + + mov(var(rs_b), r10) + lea(mem(, r10, 8), r10) + + mov(var(c), r12) + mov(var(rs_c), rdi) + lea(mem(, rdi, 8), rdi) + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + mov(var(b), rbx) + mov(r14, rax) + + + + cmp(imm(8), rdi) + jz(.DCOLPFETCH) + label(.DROWPFETCH) + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) + label(.DCOLPFETCH) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + lea(mem(r12, rsi, 2), rdx) + lea(mem(rdx, rsi, 1), rdx) + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) + + mov(var(ps_a8), rdx) + lea(mem(rax, rdx, 1), rdx) + lea(mem(r9, r9, 2), rcx) + + mov(var(k_iter), rsi) + test(rsi, rsi) + je(.DCONSIDKLEFT) + + //ymm12 and ymm14 are used for 0x6 block + label(.DLOOPKITER) // MAIN LOOP + //0 + prefetch(0, mem(rdx, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + 
vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + + vmovupd(mem(rbx, 1*64), ymm0) + add(r10, rbx) // b += rs_b; + lea(mem(rax, r13, 2), rbp) + vbroadcastsd(mem(rbp ), ymm2) + vbroadcastsd(mem(rbp, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm12) + vfmadd231pd(ymm1, ymm3, ymm14) + + add(r9, rax) // a += cs_a; + + + + //1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + + vmovupd(mem(rbx, 1*64), ymm0) + add(r10, rbx) // b += rs_b; + lea(mem(rax, r13, 2), rbp) + vbroadcastsd(mem(rbp ), ymm2) + vbroadcastsd(mem(rbp, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm12) + vfmadd231pd(ymm1, ymm3, ymm14) + + add(r9, rax) // a += cs_a; + + + //2 + + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + + vmovupd(mem(rbx, 1*64), ymm0) + add(r10, rbx) // b += rs_b; + lea(mem(rax, r13, 2), rbp) + vbroadcastsd(mem(rbp ), ymm2) + vbroadcastsd(mem(rbp, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm12) + vfmadd231pd(ymm1, ymm3, ymm14) + add(r9, rax) // a += cs_a; + + + //3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + + vmovupd(mem(rbx, 1*64), ymm0) + add(r10, rbx) // b += rs_b; + lea(mem(rax, r13, 2), rbp) + vbroadcastsd(mem(rbp ), ymm2) + vbroadcastsd(mem(rbp, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm12) + vfmadd231pd(ymm1, ymm3, ymm14) + add(r9, rax) // a += cs_a; + + + dec(rsi) + jne(.DLOOPKITER) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) + test(rsi, rsi) + je(.DPOSTACCUM) + + label(.DLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + 
add(r9, rdx) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + + vmovupd(mem(rbx, 1*64), ymm0) + add(r10, rbx) // b += rs_b; + lea(mem(rax, r13, 2), rbp) + vbroadcastsd(mem(rbp ), ymm2) + vbroadcastsd(mem(rbp, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm12) + vfmadd231pd(ymm1, ymm3, ymm14) + add(r9, rax) // a += cs_a; + + + dec(rsi) + jne(.DLOOPKLEFT) + + label(.DPOSTACCUM) + + + + mov(r12, rcx) + mov(var(alpha), rax) + mov(var(beta), rbx) + vbroadcastsd(mem(rax), ymm0) + vbroadcastsd(mem(rbx), ymm3) + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + lea(mem(rcx, rdi, 4), rdx) // c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // 3*cs_c; + + + vxorpd(ymm0, ymm0, ymm0) + vucomisd(xmm0, xmm3) + je(.DBETAZERO) + + cmp(imm(8), rdi) + jz(.DCOLSTORED) + + label(.DROWSTORED) + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovhpd(xmm6, mem(rcx, 0*32+1*8)) + vextractf128(imm(0x1), ymm6, xmm6) + vmovupd(xmm6, mem(rcx, 0*32+2*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vextractf128(imm(0x1), ymm8, xmm8) + vmovupd(xmm8, mem(rcx, 0*32+2*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vextractf128(imm(0x1), ymm10, xmm10) + vmovhpd(xmm10, mem(rcx, 0*32+3*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, rdi, 2, 1*32), ymm3, ymm12) + vextractf128(imm(0x1), ymm12, xmm12) + vmovupd(xmm12, mem(rcx, rdi, 2, 1*32+2*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, rdi, 2, 1*32), ymm3, ymm14) + vextractf128(imm(0x1), ymm14, xmm14) + vmovhpd(xmm14, mem(rcx, rdi, 2, 1*32+3*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovhpd(xmm15, mem(rcx, 1*32+1*8)) + vextractf128(imm(0x1), ymm15, xmm15) + vmovupd(xmm15, mem(rcx, 1*32+2*8)) + //add(rdi, rcx) + + + jmp(.DDONE) + + + label(.DCOLSTORED) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + 
vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovlpd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vextractf128(imm(0x1), ymm8, xmm8) + vmovlpd(xmm8, mem(rcx, rsi, 2, 1*16)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + lea(mem(rcx, 6*8), rbp) + lea(mem(rbp, rsi, 2), rbp) + vfmadd231pd(mem(rbp ), xmm3, xmm2) + vfmadd231pd(mem(rbp, rsi, 1), xmm3, xmm4) + vmovlpd(xmm2, mem(rbp)) + vmovupd(xmm4, mem(rbp, rsi, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovlpd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + jmp(.DDONE) // jump to end. 
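/*
   The .DBETAZERO branch that follows repeats the same store patterns as the
   two cases above but without loading C first, since beta == 0 requires that
   C not be read. As a scalar sketch of the update these paths implement,
   per stored element of the microtile:

     if ( beta == 0.0 ) C(i,j) = alpha * AB(i,j);
     else               C(i,j) = alpha * AB(i,j) + beta * C(i,j);

   where AB(i,j) is the accumulated dot product held in the ymm accumulators.
*/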
+ + label(.DBETAZERO) + + + cmp(imm(8), rdi) + jz(.DCOLSTORBZ) + + label(.DROWSTORBZ) + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovhpd(xmm6, mem(rcx, 0*32+1*8)) + vextractf128(imm(0x1), ymm6, xmm6) + vmovupd(xmm6, mem(rcx, 0*32+2*8)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + vextractf128(imm(0x1), ymm8, xmm8) + vmovupd(xmm8, mem(rcx, 0*32+2*8)) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + vextractf128(imm(0x1), ymm10, xmm10) + vmovhpd(xmm10, mem(rcx, 0*32+3*8)) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm12, xmm12) + vmovupd(xmm12, mem(rcx, rdi, 2, 1*32+2*8)) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm14, xmm14) + vmovhpd(xmm14, mem(rcx, rdi, 2, 1*32+3*8)) + vmovhpd(xmm15, mem(rcx, 1*32+1*8)) + vextractf128(imm(0x1), ymm15, xmm15) + vmovupd(xmm15, mem(rcx, 1*32+2*8)) + + jmp(.DDONE) + + label(.DCOLSTORBZ) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovlpd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vextractf128(imm(0x1), ymm8, xmm8) + vmovlpd(xmm8, mem(rcx, rsi, 2, 1*16)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + lea(mem(rcx, rdi, 4), rbp) + lea(mem(rbp, rdi, 2), rbp) + lea(mem(rbp, rsi, 2), rbp) + vmovlpd(xmm2, mem(rbp)) + vmovupd(xmm4, mem(rbp, rsi, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovlpd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_6x6m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, 
inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.DLOOP6X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm1, ymm1, ymm1) // zero ymm1 since we only use the lower + vxorpd(ymm4, ymm4, ymm4) // half (xmm1), and nans/infs may slow us + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) +#endif + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
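/*
   Dispatch note: rdi holds rs_c already scaled to bytes, so comparing it with
   8 is a test for rs_c == 1 element, i.e. a column-stored C; that case takes
   the column-oriented prefetch and store paths, while any other row stride
   falls through to the row-oriented paths.
*/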
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 5*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + 
vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + // ---------------------------------- iteration 3 @@ -1275,14 +8379,14 @@ void bli_dgemmsup_rv_haswell_asm_6x6m vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -1290,27 +8394,27 @@ void bli_dgemmsup_rv_haswell_asm_6x6m vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - - - + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - + #if 1 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) @@ -1319,21 +8423,21 @@ void bli_dgemmsup_rv_haswell_asm_6x6m vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), xmm1) add(r10, rbx) // b += rs_b; - + vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -1341,53 +8445,53 @@ void bli_dgemmsup_rv_haswell_asm_6x6m vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - - + + mov(r12, rcx) // reset rcx to current utile of c. 
mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha - vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm0, ymm6, ymm6) - vmulpd(xmm0, xmm7, xmm7) + vmulpd(ymm0, ymm7, ymm7) vmulpd(ymm0, ymm8, ymm8) - vmulpd(xmm0, xmm9, xmm9) + vmulpd(ymm0, ymm9, ymm9) vmulpd(ymm0, ymm10, ymm10) - vmulpd(xmm0, xmm11, xmm11) + vmulpd(ymm0, ymm11, ymm11) vmulpd(ymm0, ymm12, ymm12) - vmulpd(xmm0, xmm13, xmm13) + vmulpd(ymm0, ymm13, ymm13) vmulpd(ymm0, ymm14, ymm14) - vmulpd(xmm0, xmm15, xmm15) - - - - - - + vmulpd(ymm0, ymm15, ymm15) + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -1396,60 +8500,60 @@ void bli_dgemmsup_rv_haswell_asm_6x6m cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) vmovupd(ymm4, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) vmovupd(xmm5, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) vmovupd(ymm6, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm7) vmovupd(xmm7, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) vmovupd(ymm8, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm9) vmovupd(xmm9, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) vmovupd(ymm10, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm11) vmovupd(xmm11, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) vmovupd(ymm12, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm13) vmovupd(xmm13, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) vmovupd(ymm14, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm15) vmovupd(xmm15, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -1524,51 +8628,51 @@ void bli_dgemmsup_rv_haswell_asm_6x6m jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) vmovupd(xmm5, mem(rcx, 1*32)) add(rdi, rcx) - + vmovupd(ymm6, mem(rcx, 0*32)) vmovupd(xmm7, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm8, mem(rcx, 0*32)) vmovupd(xmm9, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm10, mem(rcx, 0*32)) vmovupd(xmm11, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm12, mem(rcx, 0*32)) vmovupd(xmm13, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm14, mem(rcx, 0*32)) vmovupd(xmm15, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. 
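/*
   Note on the alpha-scaling hunk above (6x6m): the odd accumulators
   (ymm5, ymm7, ...) only carry two valid doubles, loaded through xmm1, and
   their upper lanes are kept at zero by the VEX xmm loads, so widening the
   vmulpd from xmm to ymm does not change the stored results; presumably the
   change is for uniformity with the other accumulators (an inference, not
   something the patch states).
*/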
@@ -1625,9 +8729,9 @@ void bli_dgemmsup_rv_haswell_asm_6x6m //lea(mem(rdx, rsi, 4), rdx) - - - + + + label(.DDONE) @@ -1648,8 +8752,7 @@ void bli_dgemmsup_rv_haswell_asm_6x6m label(.DRETURN) - - + vzeroupper() end_asm( : // output operands (none) @@ -1810,9 +8913,9 @@ void bli_dgemmsup_rv_haswell_asm_6x4m // ------------------------------------------------------------------------- begin_asm() - + //vzeroall() // zero all xmm/ymm registers. - + mov(var(a), r14) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -1827,7 +8930,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4m //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -1858,11 +8961,11 @@ void bli_dgemmsup_rv_haswell_asm_6x4m // a latency of 1 cycle, while vzeroall // has a latency of 12 cycles. vxorpd(ymm4, ymm4, ymm4) - vxorpd(ymm6, ymm6, ymm6) - vxorpd(ymm8, ymm8, ymm8) - vxorpd(ymm10, ymm10, ymm10) - vxorpd(ymm12, ymm12, ymm12) - vxorpd(ymm14, ymm14, ymm14) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm14) #endif mov(var(b), rbx) // load address of b. @@ -1912,17 +9015,17 @@ void bli_dgemmsup_rv_haswell_asm_6x4m #endif - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 0 @@ -1930,7 +9033,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4m #else prefetch(0, mem(rdx, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1938,19 +9041,19 @@ void bli_dgemmsup_rv_haswell_asm_6x4m vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm0, ymm3, ymm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm0, ymm3, ymm14) - + // ---------------------------------- iteration 1 #if 0 @@ -1966,18 +9069,18 @@ void bli_dgemmsup_rv_haswell_asm_6x4m vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm0, ymm3, ymm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm0, ymm3, ymm14) - + // ---------------------------------- iteration 2 @@ -1986,7 +9089,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4m #else prefetch(0, mem(rdx, r9, 2, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1994,18 +9097,18 @@ void bli_dgemmsup_rv_haswell_asm_6x4m vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm0, ymm3, ymm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm0, ymm3, ymm14) - + // 
---------------------------------- iteration 3 @@ -2023,38 +9126,38 @@ void bli_dgemmsup_rv_haswell_asm_6x4m vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm0, ymm3, ymm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm0, ymm3, ymm14) - - - + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - + #if 1 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) @@ -2067,58 +9170,58 @@ void bli_dgemmsup_rv_haswell_asm_6x4m vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm0, ymm3, ymm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm0, ymm3, ymm14) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + mov(r12, rcx) // reset rcx to current utile of c. mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm6, ymm6) vmulpd(ymm0, ymm8, ymm8) vmulpd(ymm0, ymm10, ymm10) vmulpd(ymm0, ymm12, ymm12) vmulpd(ymm0, ymm14, ymm14) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -2127,42 +9230,42 @@ void bli_dgemmsup_rv_haswell_asm_6x4m cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) vmovupd(ymm4, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) vmovupd(ymm6, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) vmovupd(ymm8, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) vmovupd(ymm10, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) vmovupd(ymm12, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) vmovupd(ymm14, mem(rcx, 0*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -2210,45 +9313,45 @@ void bli_dgemmsup_rv_haswell_asm_6x4m jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORBZ) // jump to column storage case - - + + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(ymm6, mem(rcx, 0*32)) add(rdi, rcx) - - + + vmovupd(ymm8, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(ymm10, mem(rcx, 0*32)) add(rdi, rcx) - - + + vmovupd(ymm12, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(ymm14, mem(rcx, 0*32)) //add(rdi, rcx) - + jmp(.DDONE) // jump to end. @@ -2283,15 +9386,15 @@ void bli_dgemmsup_rv_haswell_asm_6x4m vmovupd(xmm4, mem(rdx, rax, 1)) //lea(mem(rdx, rsi, 4), rdx) - - - - + + + + label(.DDONE) - + lea(mem(r12, rdi, 4), r12) // lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c @@ -2307,8 +9410,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4m label(.DRETURN) - - + vzeroupper() end_asm( : // output operands (none) @@ -2469,9 +9571,9 @@ void bli_dgemmsup_rv_haswell_asm_6x2m // ------------------------------------------------------------------------- begin_asm() - + //vzeroall() // zero all xmm/ymm registers. - + mov(var(a), r14) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -2486,7 +9588,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -2517,11 +9619,11 @@ void bli_dgemmsup_rv_haswell_asm_6x2m // a latency of 1 cycle, while vzeroall // has a latency of 12 cycles. vxorpd(xmm4, xmm4, xmm4) - vxorpd(xmm6, xmm6, xmm6) - vxorpd(xmm8, xmm8, xmm8) - vxorpd(xmm10, xmm10, xmm10) - vxorpd(xmm12, xmm12, xmm12) - vxorpd(xmm14, xmm14, xmm14) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm14) #endif mov(var(b), rbx) // load address of b. @@ -2565,19 +9667,19 @@ void bli_dgemmsup_rv_haswell_asm_6x2m lea(mem(rdx, r8, 2), rdx) // from next upanel of a. lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; #endif - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
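/*
   Loop-structure note (common to these sup kernels): the C wrapper splits the
   k dimension as k_iter = k0 / 4 and k_left = k0 % 4, so .DLOOPKITER below is
   the 4x-unrolled main loop and .DLOOPKLEFT performs the remaining rank-1
   updates one at a time; the tests above simply skip a loop whose trip count
   is zero.
*/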
- - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 0 @@ -2585,7 +9687,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m #else prefetch(0, mem(rdx, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; @@ -2593,19 +9695,19 @@ void bli_dgemmsup_rv_haswell_asm_6x2m vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm12) vfmadd231pd(xmm0, xmm3, xmm14) - + // ---------------------------------- iteration 1 #if 0 @@ -2621,18 +9723,18 @@ void bli_dgemmsup_rv_haswell_asm_6x2m vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm12) vfmadd231pd(xmm0, xmm3, xmm14) - + // ---------------------------------- iteration 2 @@ -2641,7 +9743,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m #else prefetch(0, mem(rdx, r9, 2, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; @@ -2649,18 +9751,18 @@ void bli_dgemmsup_rv_haswell_asm_6x2m vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm12) vfmadd231pd(xmm0, xmm3, xmm14) - + // ---------------------------------- iteration 3 @@ -2678,43 +9780,43 @@ void bli_dgemmsup_rv_haswell_asm_6x2m vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm12) vfmadd231pd(xmm0, xmm3, xmm14) - - - + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP #if 1 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; @@ -2722,58 +9824,57 @@ void bli_dgemmsup_rv_haswell_asm_6x2m vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm12) vfmadd231pd(xmm0, xmm3, xmm14) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + mov(r12, rcx) // reset rcx to current utile of c. 
mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - - vmulpd(xmm0, xmm4, xmm4) // scale by alpha - vmulpd(xmm0, xmm6, xmm6) - vmulpd(xmm0, xmm8, xmm8) - vmulpd(xmm0, xmm10, xmm10) - vmulpd(xmm0, xmm12, xmm12) - vmulpd(xmm0, xmm14, xmm14) - - - - - - + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm14, ymm14) + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - - //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -2782,42 +9883,42 @@ void bli_dgemmsup_rv_haswell_asm_6x2m cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) vmovupd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) vmovupd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm8) vmovupd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm10) vmovupd(xmm10, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm12) vmovupd(xmm12, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm14) vmovupd(xmm14, mem(rcx, 0*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -2853,40 +9954,40 @@ void bli_dgemmsup_rv_haswell_asm_6x2m jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - - + + label(.DROWSTORBZ) - - + + vmovupd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - + + vmovupd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(xmm10, mem(rcx, 0*32)) add(rdi, rcx) - - + + vmovupd(xmm12, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(xmm14, mem(rcx, 0*32)) //add(rdi, rcx) @@ -2897,7 +9998,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m label(.DCOLSTORBZ) - + // begin I/O on columns 0-3 vunpcklpd(xmm6, xmm4, xmm0) vunpckhpd(xmm6, xmm4, xmm1) @@ -2918,10 +10019,10 @@ void bli_dgemmsup_rv_haswell_asm_6x2m vmovupd(xmm1, mem(rdx, rsi, 1)) //lea(mem(rdx, rsi, 4), rdx) - - - - + + + + label(.DDONE) @@ -2942,8 +10043,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m label(.DRETURN) - - + vzeroupper() end_asm( : // output operands (none) @@ -3060,3 +10160,4 @@ void bli_dgemmsup_rv_haswell_asm_6x2m AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } + diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c index 426e5157e1..c299047ff9 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2022, Advanced Micro Devices, Inc. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -4475,34 +4475,39 @@ void bli_sgemmsup_rv_haswell_asm_6x2m label(.SROWSTORED) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + + vmovsd(mem(rcx, 0*32), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + + + vmovsd(mem(rcx, 0*32), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + + + vmovsd(mem(rcx, 0*32), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + + + vmovsd(mem(rcx, 0*32), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) vmovsd(xmm10, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm12) + + + vmovsd(mem(rcx, 0*32), xmm0) + vfmadd231ps(xmm0, xmm3, xmm12) vmovsd(xmm12, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm14) + + + vmovsd(mem(rcx, 0*32), xmm0) + vfmadd231ps(xmm0, xmm3, xmm14) vmovsd(xmm14, mem(rcx, 0*32)) //add(rdi, rcx) diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c index 6e3c1a0e85..8d3900f2e8 100644 --- a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -99,9 +99,9 @@ void bli_dgemmsup_rd_haswell_asm_6x1 // ------------------------------------------------------------------------- begin_asm() - + //vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a //mov(var(cs_a), r9) // load cs_a @@ -119,7 +119,7 @@ void bli_dgemmsup_rd_haswell_asm_6x1 //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a - + mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c @@ -163,19 +163,19 @@ void bli_dgemmsup_rd_haswell_asm_6x1 prefetch(0, mem(r10, rdi, 1, 1*8)) // prefetch c + 4*rs_c prefetch(0, mem(r10, rdi, 2, 1*8)) // prefetch c + 5*rs_c #endif - - - + + + mov(var(k_iter16), rsi) // i = k_iter16; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKITER4) // if i == 0, jump to code that // contains the k_iter4 loop. - - + + label(.DLOOPKITER16) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 0 @@ -206,7 +206,7 @@ void bli_dgemmsup_rd_haswell_asm_6x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm14) - + // ---------------------------------- iteration 1 vmovupd(mem(rbx ), ymm0) @@ -233,7 +233,7 @@ void bli_dgemmsup_rd_haswell_asm_6x1 // ---------------------------------- iteration 2 - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -287,27 +287,27 @@ void bli_dgemmsup_rd_haswell_asm_6x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm14) - + dec(rsi) // i -= 1; jne(.DLOOPKITER16) // iterate again if i != 0. 
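/*
   In these rd (dot-product style) Mx1 kernels the k dimension is walked four
   doubles at a time in ymm registers: .DLOOPKITER16 above processes 16
   k-values per trip (four unrolled 4-wide steps), .DLOOPKITER4 below handles
   leftover groups of four, and .DLOOPKLEFT1 finishes any scalar remainder.
   Each accumulator therefore ends up holding four partial products for a
   single C element, which the vhaddpd/vextractf128/vaddpd sequence further
   below reduces to a scalar.
*/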
- - - - - - + + + + + + label(.DCONSIDKITER4) - + mov(var(k_iter4), rsi) // i = k_iter4; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT1) // if i == 0, jump to code that // considers k_left1 loop. // else, we prepare to enter k_iter4 loop. - - + + label(.DLOOPKITER4) // EDGE LOOP (ymm) - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -336,21 +336,21 @@ void bli_dgemmsup_rd_haswell_asm_6x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm14) - + dec(rsi) // i -= 1; jne(.DLOOPKITER4) // iterate again if i != 0. - - - + + + label(.DCONSIDKLEFT1) - + mov(var(k_left1), rsi) // i = k_left1; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left1 loop. - - + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) @@ -358,7 +358,7 @@ void bli_dgemmsup_rd_haswell_asm_6x1 // using the xmm registers would zero out the // high bits of the destination registers, // which would destory intermediate results. - + vmovsd(mem(rbx ), xmm0) add(imm(1*8), rbx) // b += 1*rs_b = 1*8; @@ -381,12 +381,12 @@ void bli_dgemmsup_rd_haswell_asm_6x1 add(imm(1*8), rax) // a += 1*cs_a = 1*8; vfmadd231pd(ymm0, ymm3, ymm14) - + dec(rsi) // i -= 1; jne(.DLOOPKLEFT1) // iterate again if i != 0. - - - + + + @@ -399,28 +399,28 @@ void bli_dgemmsup_rd_haswell_asm_6x1 // ymm10 // ymm12 // ymm14 - - vhaddpd( ymm5, ymm4, ymm0 ) + + vhaddpd( ymm4, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm4 ) - vhaddpd( ymm7, ymm6, ymm0 ) + vhaddpd( ymm6, ymm6, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm6 ) - vhaddpd( ymm9, ymm8, ymm0 ) + vhaddpd( ymm8, ymm8, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm8 ) - vhaddpd( ymm11, ymm10, ymm0 ) + vhaddpd( ymm10, ymm10, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm10 ) - vhaddpd( ymm13, ymm12, ymm0 ) + vhaddpd( ymm12, ymm12, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm12 ) - vhaddpd( ymm15, ymm14, ymm0 ) + vhaddpd( ymm14, ymm14, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm14 ) @@ -435,114 +435,114 @@ void bli_dgemmsup_rd_haswell_asm_6x1 //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(double) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(xmm0, xmm4, xmm4) // scale by alpha vmulpd(xmm0, xmm6, xmm6) vmulpd(xmm0, xmm8, xmm8) vmulpd(xmm0, xmm10, xmm10) vmulpd(xmm0, xmm12, xmm12) vmulpd(xmm0, xmm14, xmm14) - - - - - - + + + + + + //mov(var(cs_c), rsi) // load cs_c //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. 
je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + label(.DROWSTORED) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx)) add(rdi, rcx) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx)) add(rdi, rcx) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx)) add(rdi, rcx) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm10) vmovsd(xmm10, mem(rcx)) add(rdi, rcx) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm12) vmovsd(xmm12, mem(rcx)) add(rdi, rcx) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm14) vmovsd(xmm14, mem(rcx)) //add(rdi, rcx) - - - + + + jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - - + + label(.DROWSTORBZ) - - + + vmovsd(xmm4, mem(rcx)) add(rdi, rcx) - + vmovsd(xmm6, mem(rcx)) add(rdi, rcx) - + vmovsd(xmm8, mem(rcx)) add(rdi, rcx) - + vmovsd(xmm10, mem(rcx)) add(rdi, rcx) - + vmovsd(xmm12, mem(rcx)) add(rdi, rcx) - + vmovsd(xmm14, mem(rcx)) //add(rdi, rcx) - - - - + + + + label(.DDONE) - + label(.DRETURN) - + end_asm( : // output operands (none) @@ -613,9 +613,9 @@ void bli_dgemmsup_rd_haswell_asm_3x1 // ------------------------------------------------------------------------- begin_asm() - + //vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a //mov(var(cs_a), r9) // load cs_a @@ -633,7 +633,7 @@ void bli_dgemmsup_rd_haswell_asm_3x1 //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a - + mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c @@ -671,19 +671,19 @@ void bli_dgemmsup_rd_haswell_asm_3x1 prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c #endif - - - + + + mov(var(k_iter16), rsi) // i = k_iter16; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKITER4) // if i == 0, jump to code that // contains the k_iter4 loop. - - + + label(.DLOOPKITER16) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 0 @@ -705,7 +705,7 @@ void bli_dgemmsup_rd_haswell_asm_3x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm8) - + // ---------------------------------- iteration 1 vmovupd(mem(rbx ), ymm0) @@ -723,7 +723,7 @@ void bli_dgemmsup_rd_haswell_asm_3x1 // ---------------------------------- iteration 2 - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -759,27 +759,27 @@ void bli_dgemmsup_rd_haswell_asm_3x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm8) - + dec(rsi) // i -= 1; jne(.DLOOPKITER16) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKITER4) - + mov(var(k_iter4), rsi) // i = k_iter4; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT1) // if i == 0, jump to code that // considers k_left1 loop. // else, we prepare to enter k_iter4 loop. - - + + label(.DLOOPKITER4) // EDGE LOOP (ymm) - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -799,21 +799,21 @@ void bli_dgemmsup_rd_haswell_asm_3x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm8) - + dec(rsi) // i -= 1; jne(.DLOOPKITER4) // iterate again if i != 0. 
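The vucomisd / je(.DBETAZERO) pair ahead of the row-stored updates implements the usual beta special case: when beta is zero the kernel never loads C and simply stores the alpha-scaled results, which skips the loads and keeps any uninitialized contents of C out of the output. A scalar sketch of the two paths (hypothetical helper, using BLIS's dim_t; ab holds the already alpha-scaled products):

#include "blis.h"   // for dim_t

static void update_row( double* c, dim_t n, double beta, const double* ab )
{
    if ( beta == 0.0 )
        for ( dim_t j = 0; j < n; ++j ) c[ j ] = ab[ j ];                 // .DROWSTORBZ path
    else
        for ( dim_t j = 0; j < n; ++j ) c[ j ] = beta * c[ j ] + ab[ j ]; // .DROWSTORED path
}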
- - - + + + label(.DCONSIDKLEFT1) - + mov(var(k_left1), rsi) // i = k_left1; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left1 loop. - - + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) @@ -821,7 +821,7 @@ void bli_dgemmsup_rd_haswell_asm_3x1 // using the xmm registers would zero out the // high bits of the destination registers, // which would destory intermediate results. - + vmovsd(mem(rbx ), xmm0) add(imm(1*8), rbx) // b += 1*rs_b = 1*8; @@ -835,12 +835,12 @@ void bli_dgemmsup_rd_haswell_asm_3x1 add(imm(1*8), rax) // a += 1*cs_a = 1*8; vfmadd231pd(ymm0, ymm3, ymm8) - + dec(rsi) // i -= 1; jne(.DLOOPKLEFT1) // iterate again if i != 0. - - - + + + @@ -850,16 +850,16 @@ void bli_dgemmsup_rd_haswell_asm_3x1 // ymm4 // ymm6 // ymm8 - - vhaddpd( ymm5, ymm4, ymm0 ) + + vhaddpd( ymm4, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm4 ) - vhaddpd( ymm7, ymm6, ymm0 ) + vhaddpd( ymm6, ymm6, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm6 ) - vhaddpd( ymm9, ymm8, ymm0 ) + vhaddpd( ymm8, ymm8, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm8 ) @@ -871,87 +871,87 @@ void bli_dgemmsup_rd_haswell_asm_3x1 //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(double) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(xmm0, xmm4, xmm4) // scale by alpha vmulpd(xmm0, xmm6, xmm6) vmulpd(xmm0, xmm8, xmm8) - - - - - - + + + + + + //mov(var(cs_c), rsi) // load cs_c //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + label(.DROWSTORED) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx)) add(rdi, rcx) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx)) add(rdi, rcx) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx)) //add(rdi, rcx) - - - + + + jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - - + + label(.DROWSTORBZ) - - + + vmovsd(xmm4, mem(rcx)) add(rdi, rcx) - + vmovsd(xmm6, mem(rcx)) add(rdi, rcx) - + vmovsd(xmm8, mem(rcx)) //add(rdi, rcx) - - - - + + + + label(.DDONE) - + label(.DRETURN) - + end_asm( : // output operands (none) @@ -1022,9 +1022,9 @@ void bli_dgemmsup_rd_haswell_asm_2x1 // ------------------------------------------------------------------------- begin_asm() - + //vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a //mov(var(cs_a), r9) // load cs_a @@ -1042,7 +1042,7 @@ void bli_dgemmsup_rd_haswell_asm_2x1 //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a - + mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c @@ -1078,19 +1078,19 @@ void bli_dgemmsup_rd_haswell_asm_2x1 prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c #endif - - - + + + mov(var(k_iter16), rsi) // i = k_iter16; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKITER4) // if i == 0, jump to code that // contains the k_iter4 loop. 
- - + + label(.DLOOPKITER16) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 0 @@ -1109,7 +1109,7 @@ void bli_dgemmsup_rd_haswell_asm_2x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm6) - + // ---------------------------------- iteration 1 vmovupd(mem(rbx ), ymm0) @@ -1124,7 +1124,7 @@ void bli_dgemmsup_rd_haswell_asm_2x1 // ---------------------------------- iteration 2 - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -1154,27 +1154,27 @@ void bli_dgemmsup_rd_haswell_asm_2x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm6) - + dec(rsi) // i -= 1; jne(.DLOOPKITER16) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKITER4) - + mov(var(k_iter4), rsi) // i = k_iter4; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT1) // if i == 0, jump to code that // considers k_left1 loop. // else, we prepare to enter k_iter4 loop. - - + + label(.DLOOPKITER4) // EDGE LOOP (ymm) - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -1191,21 +1191,21 @@ void bli_dgemmsup_rd_haswell_asm_2x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm6) - + dec(rsi) // i -= 1; jne(.DLOOPKITER4) // iterate again if i != 0. - - - + + + label(.DCONSIDKLEFT1) - + mov(var(k_left1), rsi) // i = k_left1; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left1 loop. - - + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) @@ -1213,7 +1213,7 @@ void bli_dgemmsup_rd_haswell_asm_2x1 // using the xmm registers would zero out the // high bits of the destination registers, // which would destory intermediate results. - + vmovsd(mem(rbx ), xmm0) add(imm(1*8), rbx) // b += 1*rs_b = 1*8; @@ -1224,12 +1224,12 @@ void bli_dgemmsup_rd_haswell_asm_2x1 add(imm(1*8), rax) // a += 1*cs_a = 1*8; vfmadd231pd(ymm0, ymm3, ymm6) - + dec(rsi) // i -= 1; jne(.DLOOPKLEFT1) // iterate again if i != 0. - - - + + + @@ -1238,12 +1238,12 @@ void bli_dgemmsup_rd_haswell_asm_2x1 // ymm4 // ymm6 - - vhaddpd( ymm5, ymm4, ymm0 ) + + vhaddpd( ymm4, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm4 ) - vhaddpd( ymm7, ymm6, ymm0 ) + vhaddpd( ymm6, ymm6, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm6 ) @@ -1254,78 +1254,78 @@ void bli_dgemmsup_rd_haswell_asm_2x1 //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(double) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(xmm0, xmm4, xmm4) // scale by alpha vmulpd(xmm0, xmm6, xmm6) - - - - - - + + + + + + //mov(var(cs_c), rsi) // load cs_c //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + label(.DROWSTORED) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx)) add(rdi, rcx) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx)) //add(rdi, rcx) - - - + + + jmp(.DDONE) // jump to end. 
- - - - + + + + label(.DBETAZERO) - - + + label(.DROWSTORBZ) - - + + vmovsd(xmm4, mem(rcx)) add(rdi, rcx) - + vmovsd(xmm6, mem(rcx)) //add(rdi, rcx) - - - - + + + + label(.DDONE) - + label(.DRETURN) - + end_asm( : // output operands (none) @@ -1396,9 +1396,9 @@ void bli_dgemmsup_rd_haswell_asm_1x1 // ------------------------------------------------------------------------- begin_asm() - + //vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a //mov(var(cs_a), r9) // load cs_a @@ -1416,7 +1416,7 @@ void bli_dgemmsup_rd_haswell_asm_1x1 //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a - + mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c @@ -1450,19 +1450,19 @@ void bli_dgemmsup_rd_haswell_asm_1x1 //lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c #endif - - - + + + mov(var(k_iter16), rsi) // i = k_iter16; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKITER4) // if i == 0, jump to code that // contains the k_iter4 loop. - - + + label(.DLOOPKITER16) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 0 @@ -1478,7 +1478,7 @@ void bli_dgemmsup_rd_haswell_asm_1x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm4) - + // ---------------------------------- iteration 1 vmovupd(mem(rbx ), ymm0) @@ -1490,7 +1490,7 @@ void bli_dgemmsup_rd_haswell_asm_1x1 // ---------------------------------- iteration 2 - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -1514,27 +1514,27 @@ void bli_dgemmsup_rd_haswell_asm_1x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm4) - + dec(rsi) // i -= 1; jne(.DLOOPKITER16) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKITER4) - + mov(var(k_iter4), rsi) // i = k_iter4; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT1) // if i == 0, jump to code that // considers k_left1 loop. // else, we prepare to enter k_iter4 loop. - - + + label(.DLOOPKITER4) // EDGE LOOP (ymm) - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -1548,21 +1548,21 @@ void bli_dgemmsup_rd_haswell_asm_1x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm4) - + dec(rsi) // i -= 1; jne(.DLOOPKITER4) // iterate again if i != 0. - - - + + + label(.DCONSIDKLEFT1) - + mov(var(k_left1), rsi) // i = k_left1; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left1 loop. - - + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) @@ -1570,7 +1570,7 @@ void bli_dgemmsup_rd_haswell_asm_1x1 // using the xmm registers would zero out the // high bits of the destination registers, // which would destory intermediate results. - + vmovsd(mem(rbx ), xmm0) add(imm(1*8), rbx) // b += 1*rs_b = 1*8; @@ -1578,12 +1578,12 @@ void bli_dgemmsup_rd_haswell_asm_1x1 add(imm(1*8), rax) // a += 1*cs_a = 1*8; vfmadd231pd(ymm0, ymm3, ymm4) - + dec(rsi) // i -= 1; jne(.DLOOPKLEFT1) // iterate again if i != 0. 
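All of these rd kernels walk the k dimension the same way: a 16x-unrolled main loop (.DLOOPKITER16), a 4x-unrolled edge loop (.DLOOPKITER4), and a scalar remainder loop (.DLOOPKLEFT1). The iteration counts passed in as k_iter16 / k_iter4 / k_left1 partition k as k = 16*k_iter16 + 4*k_iter4 + k_left1; a small sketch of that bookkeeping (the wrapper code that actually computes these lives outside the hunks shown here):

#include <stdint.h>

// k = 16*k_iter16 + 4*k_iter4 + k_left1
static void partition_k( uint64_t k,
                         uint64_t* k_iter16, uint64_t* k_iter4, uint64_t* k_left1 )
{
    *k_iter16      = k / 16;     // iterations of the 16x-unrolled main loop
    uint64_t k_rem = k % 16;
    *k_iter4       = k_rem / 4;  // iterations of the 4x-unrolled edge loop
    *k_left1       = k_rem % 4;  // scalar remainder iterations
}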
- - - + + + @@ -1591,8 +1591,8 @@ void bli_dgemmsup_rd_haswell_asm_1x1 label(.DPOSTACCUM) // ymm4 - - vhaddpd( ymm5, ymm4, ymm0 ) + + vhaddpd( ymm4, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm4 ) @@ -1602,69 +1602,69 @@ void bli_dgemmsup_rd_haswell_asm_1x1 //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(double) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(xmm0, xmm4, xmm4) // scale by alpha - - - - - - + + + + + + //mov(var(cs_c), rsi) // load cs_c //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + label(.DROWSTORED) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx)) //add(rdi, rcx) - - - + + + jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - - + + label(.DROWSTORBZ) - - + + vmovsd(xmm4, mem(rcx)) //add(rdi, rcx) - - - - + + + + label(.DDONE) - + label(.DRETURN) - + end_asm( : // output operands (none) diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c index 4c6094b1cd..f19b703b41 100644 --- a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -101,7 +101,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4 begin_asm() //vzeroall() // zero all xmm/ymm registers. - + mov(var(a), r14) // load address of a. mov(var(rs_a), r8) // load rs_a //mov(var(cs_a), r9) // load cs_a @@ -119,7 +119,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4 lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a - + mov(var(c), r12) // load address of c mov(var(rs_c), rdi) // load rs_c @@ -172,19 +172,19 @@ void bli_dgemmsup_rd_haswell_asm_6x4 prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c #endif lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a - - - + + + mov(var(k_iter16), rsi) // i = k_iter16; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKITER4) // if i == 0, jump to code that // contains the k_iter4 loop. - - + + label(.DLOOPKITER16) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 0 @@ -219,7 +219,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4 vfmadd231pd(ymm1, ymm3, ymm14) vfmadd231pd(ymm2, ymm3, ymm15) - + // ---------------------------------- iteration 1 vmovupd(mem(rax ), ymm0) @@ -250,7 +250,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4 // ---------------------------------- iteration 2 - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -312,27 +312,27 @@ void bli_dgemmsup_rd_haswell_asm_6x4 vfmadd231pd(ymm1, ymm3, ymm14) vfmadd231pd(ymm2, ymm3, ymm15) - + dec(rsi) // i -= 1; jne(.DLOOPKITER16) // iterate again if i != 0. 
- - - - - - + + + + + + label(.DCONSIDKITER4) - + mov(var(k_iter4), rsi) // i = k_iter4; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT1) // if i == 0, jump to code that // considers k_left1 loop. // else, we prepare to enter k_iter4 loop. - - + + label(.DLOOPKITER4) // EDGE LOOP (ymm) - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -343,7 +343,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4 vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) add(imm(4*8), rax) // a += 4*cs_b = 4*8; - + vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) vfmadd231pd(ymm1, ymm3, ymm5) @@ -365,21 +365,21 @@ void bli_dgemmsup_rd_haswell_asm_6x4 vfmadd231pd(ymm1, ymm3, ymm14) vfmadd231pd(ymm2, ymm3, ymm15) - + dec(rsi) // i -= 1; jne(.DLOOPKITER4) // iterate again if i != 0. - - - + + + label(.DCONSIDKLEFT1) - + mov(var(k_left1), rsi) // i = k_left1; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left1 loop. - - + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) @@ -387,12 +387,12 @@ void bli_dgemmsup_rd_haswell_asm_6x4 // using the xmm registers would zero out the // high bits of the destination registers, // which would destory intermediate results. - + vmovsd(mem(rax ), xmm0) vmovsd(mem(rax, r8, 1), xmm1) vmovsd(mem(rax, r8, 2), xmm2) add(imm(1*8), rax) // a += 1*cs_a = 1*8; - + vmovsd(mem(rbx ), xmm3) vfmadd231pd(ymm0, ymm3, ymm4) vfmadd231pd(ymm1, ymm3, ymm5) @@ -414,12 +414,12 @@ void bli_dgemmsup_rd_haswell_asm_6x4 vfmadd231pd(ymm1, ymm3, ymm14) vfmadd231pd(ymm2, ymm3, ymm15) - + dec(rsi) // i -= 1; jne(.DLOOPKLEFT1) // iterate again if i != 0. - - - + + + @@ -427,11 +427,11 @@ void bli_dgemmsup_rd_haswell_asm_6x4 label(.DPOSTACCUM) - - // ymm4 ymm7 ymm10 ymm13 + + // ymm4 ymm7 ymm10 ymm13 // ymm5 ymm8 ymm11 ymm14 // ymm6 ymm9 ymm12 ymm15 - + vhaddpd( ymm7, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm0 ) @@ -469,7 +469,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4 // xmm6[0:3] = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) - + //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) @@ -477,73 +477,73 @@ void bli_dgemmsup_rd_haswell_asm_6x4 mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm0, ymm6, ymm6) - - - - - - + + + + + + //mov(var(cs_c), rsi) // load cs_c //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx), ymm3, ymm4) vmovupd(ymm4, mem(rcx)) add(rdi, rcx) - + vfmadd231pd(mem(rcx), ymm3, ymm5) vmovupd(ymm5, mem(rcx)) add(rdi, rcx) - + vfmadd231pd(mem(rcx), ymm3, ymm6) vmovupd(ymm6, mem(rcx)) //add(rdi, rcx) - - - + + + jmp(.DDONE) // jump to end. 
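In the dMx4 kernels each output row keeps four k-partial accumulators (e.g. ymm4, ymm7, ymm10, ymm13 for the first row), and the .DPOSTACCUM sequence of vhaddpd / vextractf128 / vaddpd / vperm2f128 folds them into a single ymm register holding the four dot products, ready for the vector row-stored update. An equivalent intrinsics sketch of that fold:

#include <immintrin.h>

// Folds four __m256d accumulators into { sum(a), sum(b), sum(c), sum(d) }.
static __m256d fold4_m256d( __m256d a, __m256d b, __m256d c, __m256d d )
{
    // Pairwise sums within each 128-bit half:
    __m256d ab = _mm256_hadd_pd( a, b );  // { a0+a1, b0+b1, a2+a3, b2+b3 }
    __m256d cd = _mm256_hadd_pd( c, d );  // { c0+c1, d0+d1, c2+c3, d2+d3 }

    // Add the high halves onto the low halves:
    __m128d ab_s = _mm_add_pd( _mm256_castpd256_pd128( ab ),
                               _mm256_extractf128_pd( ab, 1 ) );  // { sum(a), sum(b) }
    __m128d cd_s = _mm_add_pd( _mm256_castpd256_pd128( cd ),
                               _mm256_extractf128_pd( cd, 1 ) );  // { sum(c), sum(d) }

    // Concatenate the two halves, as vperm2f128 with immediate 0x20 does.
    return _mm256_insertf128_pd( _mm256_castpd128_pd256( ab_s ), cd_s, 1 );
}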
- - - - + + + + label(.DBETAZERO) - - + + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx)) add(rdi, rcx) - + vmovupd(ymm5, mem(rcx)) add(rdi, rcx) - + vmovupd(ymm6, mem(rcx)) //add(rdi, rcx) - - - - + + + + label(.DDONE) - - + + lea(mem(r12, rdi, 2), r12) // @@ -560,7 +560,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4 label(.DRETURN) - + end_asm( : // output operands (none) @@ -629,7 +629,7 @@ void bli_dgemmsup_rd_haswell_asm_2x4 // ------------------------------------------------------------------------- begin_asm() - + //vzeroall() // zero all xmm/ymm registers. mov(var(a), rax) // load address of a. @@ -649,7 +649,7 @@ void bli_dgemmsup_rd_haswell_asm_2x4 lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a - + mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c @@ -682,7 +682,7 @@ void bli_dgemmsup_rd_haswell_asm_2x4 //lea(mem(r14), rax) // rax = a; //lea(mem(rdx), rbx) // rbx = b; - + #if 1 //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) @@ -690,18 +690,18 @@ void bli_dgemmsup_rd_haswell_asm_2x4 prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c #endif - - - + + + mov(var(k_iter16), rsi) // i = k_iter16; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKITER4) // if i == 0, jump to code that // contains the k_iter4 loop. - - + + label(.DLOOPKITER16) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 0 @@ -730,7 +730,7 @@ void bli_dgemmsup_rd_haswell_asm_2x4 vfmadd231pd(ymm0, ymm3, ymm13) vfmadd231pd(ymm1, ymm3, ymm14) - + // ---------------------------------- iteration 1 vmovupd(mem(rax ), ymm0) @@ -756,7 +756,7 @@ void bli_dgemmsup_rd_haswell_asm_2x4 // ---------------------------------- iteration 2 - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -807,27 +807,27 @@ void bli_dgemmsup_rd_haswell_asm_2x4 vfmadd231pd(ymm0, ymm3, ymm13) vfmadd231pd(ymm1, ymm3, ymm14) - + dec(rsi) // i -= 1; jne(.DLOOPKITER16) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKITER4) - + mov(var(k_iter4), rsi) // i = k_iter4; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT1) // if i == 0, jump to code that // considers k_left1 loop. // else, we prepare to enter k_iter4 loop. - - + + label(.DLOOPKITER4) // EDGE LOOP (ymm) - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -836,7 +836,7 @@ void bli_dgemmsup_rd_haswell_asm_2x4 vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) add(imm(4*8), rax) // a += 4*cs_b = 4*8; - + vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) vfmadd231pd(ymm1, ymm3, ymm5) @@ -854,21 +854,21 @@ void bli_dgemmsup_rd_haswell_asm_2x4 vfmadd231pd(ymm0, ymm3, ymm13) vfmadd231pd(ymm1, ymm3, ymm14) - + dec(rsi) // i -= 1; jne(.DLOOPKITER4) // iterate again if i != 0. - - - + + + label(.DCONSIDKLEFT1) - + mov(var(k_left1), rsi) // i = k_left1; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left1 loop. - - + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) @@ -876,11 +876,11 @@ void bli_dgemmsup_rd_haswell_asm_2x4 // using the xmm registers would zero out the // high bits of the destination registers, // which would destory intermediate results. 
- + vmovsd(mem(rax ), xmm0) vmovsd(mem(rax, r8, 1), xmm1) add(imm(1*8), rax) // a += 1*cs_a = 1*8; - + vmovsd(mem(rbx ), xmm3) vfmadd231pd(ymm0, ymm3, ymm4) vfmadd231pd(ymm1, ymm3, ymm5) @@ -898,12 +898,12 @@ void bli_dgemmsup_rd_haswell_asm_2x4 vfmadd231pd(ymm0, ymm3, ymm13) vfmadd231pd(ymm1, ymm3, ymm14) - + dec(rsi) // i -= 1; jne(.DLOOPKLEFT1) // iterate again if i != 0. - - - + + + @@ -911,10 +911,10 @@ void bli_dgemmsup_rd_haswell_asm_2x4 label(.DPOSTACCUM) - - // ymm4 ymm7 ymm10 ymm13 + + // ymm4 ymm7 ymm10 ymm13 // ymm5 ymm8 ymm11 ymm14 - + vhaddpd( ymm7, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm0 ) @@ -943,75 +943,75 @@ void bli_dgemmsup_rd_haswell_asm_2x4 //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) - - - - - - + + + + + + //mov(var(cs_c), rsi) // load cs_c //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx), ymm3, ymm4) vmovupd(ymm4, mem(rcx)) add(rdi, rcx) - + vfmadd231pd(mem(rcx), ymm3, ymm5) vmovupd(ymm5, mem(rcx)) //add(rdi, rcx) - - - + + + jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - - + + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx)) add(rdi, rcx) - + vmovupd(ymm5, mem(rcx)) //add(rdi, rcx) - - - - + + + + label(.DDONE) label(.DRETURN) - - + + end_asm( : // output operands (none) @@ -1079,7 +1079,7 @@ void bli_dgemmsup_rd_haswell_asm_1x4 // ------------------------------------------------------------------------- begin_asm() - + //vzeroall() // zero all xmm/ymm registers. mov(var(a), rax) // load address of a. @@ -1099,7 +1099,7 @@ void bli_dgemmsup_rd_haswell_asm_1x4 lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a - + mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c @@ -1128,26 +1128,26 @@ void bli_dgemmsup_rd_haswell_asm_1x4 //lea(mem(r14), rax) // rax = a; //lea(mem(rdx), rbx) // rbx = b; - + #if 1 //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c - prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + //prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c #endif - - - + + + mov(var(k_iter16), rsi) // i = k_iter16; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKITER4) // if i == 0, jump to code that // contains the k_iter4 loop. 
- - + + label(.DLOOPKITER16) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 0 @@ -1170,7 +1170,7 @@ void bli_dgemmsup_rd_haswell_asm_1x4 add(imm(4*8), rbx) // b += 4*rs_b = 4*8; vfmadd231pd(ymm0, ymm3, ymm13) - + // ---------------------------------- iteration 1 vmovupd(mem(rax ), ymm0) @@ -1191,7 +1191,7 @@ void bli_dgemmsup_rd_haswell_asm_1x4 // ---------------------------------- iteration 2 - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a #endif @@ -1231,27 +1231,27 @@ void bli_dgemmsup_rd_haswell_asm_1x4 add(imm(4*8), rbx) // b += 4*rs_b = 4*8; vfmadd231pd(ymm0, ymm3, ymm13) - + dec(rsi) // i -= 1; jne(.DLOOPKITER16) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKITER4) - + mov(var(k_iter4), rsi) // i = k_iter4; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT1) // if i == 0, jump to code that // considers k_left1 loop. // else, we prepare to enter k_iter4 loop. - - + + label(.DLOOPKITER4) // EDGE LOOP (ymm) - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -1259,7 +1259,7 @@ void bli_dgemmsup_rd_haswell_asm_1x4 vmovupd(mem(rax ), ymm0) add(imm(4*8), rax) // a += 4*cs_b = 4*8; - + vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -1273,21 +1273,21 @@ void bli_dgemmsup_rd_haswell_asm_1x4 add(imm(4*8), rbx) // b += 4*rs_b = 4*8; vfmadd231pd(ymm0, ymm3, ymm13) - + dec(rsi) // i -= 1; jne(.DLOOPKITER4) // iterate again if i != 0. - - - + + + label(.DCONSIDKLEFT1) - + mov(var(k_left1), rsi) // i = k_left1; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left1 loop. - - + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) @@ -1295,10 +1295,10 @@ void bli_dgemmsup_rd_haswell_asm_1x4 // using the xmm registers would zero out the // high bits of the destination registers, // which would destory intermediate results. - + vmovsd(mem(rax ), xmm0) add(imm(1*8), rax) // a += 1*cs_a = 1*8; - + vmovsd(mem(rbx ), xmm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -1312,12 +1312,12 @@ void bli_dgemmsup_rd_haswell_asm_1x4 add(imm(1*8), rbx) // b += 1*rs_b = 1*8; vfmadd231pd(ymm0, ymm3, ymm13) - + dec(rsi) // i -= 1; jne(.DLOOPKLEFT1) // iterate again if i != 0. 
- - - + + + @@ -1325,9 +1325,9 @@ void bli_dgemmsup_rd_haswell_asm_1x4 label(.DPOSTACCUM) - - // ymm4 ymm7 ymm10 ymm13 - + + // ymm4 ymm7 ymm10 ymm13 + vhaddpd( ymm7, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm0 ) @@ -1339,83 +1339,82 @@ void bli_dgemmsup_rd_haswell_asm_1x4 vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) - vhaddpd( ymm8, ymm5, ymm0 ) - vextractf128(imm(1), ymm0, xmm1 ) - vaddpd( xmm0, xmm1, xmm0 ) + //vhaddpd( ymm8, ymm5, ymm0 ) + //vextractf128(imm(1), ymm0, xmm1 ) + //vaddpd( xmm0, xmm1, xmm0 ) - vhaddpd( ymm14, ymm11, ymm2 ) - vextractf128(imm(1), ymm2, xmm1 ) - vaddpd( xmm2, xmm1, xmm2 ) + //vhaddpd( ymm14, ymm11, ymm2 ) + //vextractf128(imm(1), ymm2, xmm1 ) + //vaddpd( xmm2, xmm1, xmm2 ) - vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + //vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) // xmm4[0:3] = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) - //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha - - - - - - + + + + + + //mov(var(cs_c), rsi) // load cs_c //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx), ymm3, ymm4) vmovupd(ymm4, mem(rcx)) //add(rdi, rcx) - - - + + + jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - - + + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx)) //add(rdi, rcx) - - - - + + + + label(.DDONE) label(.DRETURN) - - + + end_asm( : // output operands (none) diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c index 6090f8b0b9..3d90e6e4f3 100644 --- a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2022, Advanced Micro Devices, Inc. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -387,34 +387,39 @@ void bli_sgemmsup_rv_haswell_asm_6x2 label(.SROWSTORED) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) vmovsd(xmm10, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm12) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm12) vmovsd(xmm12, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm14) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm14) vmovsd(xmm14, mem(rcx, 0*32)) //add(rdi, rcx) @@ -846,29 +851,33 @@ void bli_sgemmsup_rv_haswell_asm_5x2 label(.SROWSTORED) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) vmovsd(xmm10, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm12) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm12) vmovsd(xmm12, mem(rcx, 0*32)) //add(rdi, rcx) @@ -1286,24 +1295,27 @@ void bli_sgemmsup_rv_haswell_asm_4x2 label(.SROWSTORED) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) vmovsd(xmm10, mem(rcx, 0*32)) //add(rdi, rcx) @@ -1681,19 +1693,21 @@ void bli_sgemmsup_rv_haswell_asm_3x2 label(.SROWSTORED) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx, 0*32)) //add(rdi, rcx) @@ -2064,14 +2078,15 @@ void bli_sgemmsup_rv_haswell_asm_2x2 label(.SROWSTORED) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) //add(rdi, rcx) @@ 
-2402,9 +2417,9 @@ void bli_sgemmsup_rv_haswell_asm_1x2 label(.SROWSTORED) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) //add(rdi, rcx) diff --git a/kernels/haswell/bli_kernels_haswell.h b/kernels/haswell/bli_kernels_haswell.h index 1c35122a4e..d841d715f3 100644 --- a/kernels/haswell/bli_kernels_haswell.h +++ b/kernels/haswell/bli_kernels_haswell.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -278,6 +278,38 @@ GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_1x1 ) // gemmsup_rd (mkernel in m dim) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_0x0_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_6x0_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_6x8_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_12x8_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_12x16_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_18x16_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_0x0_combined_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_0x0_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_6x0_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_6x8_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_12x8_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_12x16_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_18x16_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_16x12_combined_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_0x0_combined_U ) + +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_0x0_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_6x0_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_6x8_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_12x8_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_12x16_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_18x16_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_0x0_combined_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_0x0_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_6x0_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_6x8_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_12x8_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_12x16_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_18x16_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_16x12_combined_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_0x0_combined_U ) + GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m ) GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x4m ) GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x2m ) diff --git a/kernels/skx/3/CMakeLists.txt b/kernels/skx/3/CMakeLists.txt new file mode 100644 index 0000000000..30857ba975 --- /dev/null +++ b/kernels/skx/3/CMakeLists.txt @@ -0,0 +1,7 @@ +##Copyright (C) 2022, Advanced Micro Devices, Inc. 
All rights reserved.## + +target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_skx_asm_16x14.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_sgemm_skx_asm_32x12_l2.c + ) diff --git a/kernels/skx/3/bli_dgemm_skx_asm_16x14.c b/kernels/skx/3/bli_dgemm_skx_asm_16x14.c index 136f315323..c0ada1eb66 100644 --- a/kernels/skx/3/bli_dgemm_skx_asm_16x14.c +++ b/kernels/skx/3/bli_dgemm_skx_asm_16x14.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc.All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -175,33 +176,33 @@ void bli_dgemm_skx_asm_16x14( BEGIN_ASM() VXORPD(YMM( 4), YMM( 4), YMM( 4)) //clear out registers - VXORPD(YMM( 5), YMM( 5), YMM( 5)) - VXORPD(YMM( 6), YMM( 6), YMM( 6)) - VXORPD(YMM( 7), YMM( 7), YMM( 7)) - VXORPD(YMM( 8), YMM( 8), YMM( 8)) - VXORPD(YMM( 9), YMM( 9), YMM( 9)) - VXORPD(YMM(10), YMM(10), YMM(10)) - VXORPD(YMM(11), YMM(11), YMM(11)) - VXORPD(YMM(12), YMM(12), YMM(12)) - VXORPD(YMM(13), YMM(13), YMM(13)) - VXORPD(YMM(14), YMM(14), YMM(14)) - VXORPD(YMM(15), YMM(15), YMM(15)) - VXORPD(YMM(16), YMM(16), YMM(16)) - VXORPD(YMM(17), YMM(17), YMM(17)) - VXORPD(YMM(18), YMM(18), YMM(18)) - VXORPD(YMM(19), YMM(19), YMM(19)) - VXORPD(YMM(20), YMM(20), YMM(20)) - VXORPD(YMM(21), YMM(21), YMM(21)) - VXORPD(YMM(22), YMM(22), YMM(22)) - VXORPD(YMM(23), YMM(23), YMM(23)) - VXORPD(YMM(24), YMM(24), YMM(24)) - VXORPD(YMM(25), YMM(25), YMM(25)) - VXORPD(YMM(26), YMM(26), YMM(26)) - VXORPD(YMM(27), YMM(27), YMM(27)) - VXORPD(YMM(28), YMM(28), YMM(28)) - VXORPD(YMM(29), YMM(29), YMM(29)) - VXORPD(YMM(30), YMM(30), YMM(30)) - VXORPD(YMM(31), YMM(31), YMM(31)) + VMOVAPD(YMM(5) , YMM(4)) + VMOVAPD(YMM(6) , YMM(4)) + VMOVAPD(YMM(7) , YMM(4)) + VMOVAPD(YMM(8) , YMM(4)) + VMOVAPD(YMM(9) , YMM(4)) + VMOVAPD(YMM(10), YMM(4)) + VMOVAPD(YMM(11), YMM(4)) + VMOVAPD(YMM(12), YMM(4)) + VMOVAPD(YMM(13), YMM(4)) + VMOVAPD(YMM(14), YMM(4)) + VMOVAPD(YMM(15), YMM(4)) + VMOVAPD(YMM(16), YMM(4)) + VMOVAPD(YMM(17), YMM(4)) + VMOVAPD(YMM(18), YMM(4)) + VMOVAPD(YMM(19), YMM(4)) + VMOVAPD(YMM(20), YMM(4)) + VMOVAPD(YMM(21), YMM(4)) + VMOVAPD(YMM(22), YMM(4)) + VMOVAPD(YMM(23), YMM(4)) + VMOVAPD(YMM(24), YMM(4)) + VMOVAPD(YMM(25), YMM(4)) + VMOVAPD(YMM(26), YMM(4)) + VMOVAPD(YMM(27), YMM(4)) + VMOVAPD(YMM(28), YMM(4)) + VMOVAPD(YMM(29), YMM(4)) + VMOVAPD(YMM(30), YMM(4)) + VMOVAPD(YMM(31), YMM(4)) MOV(RSI, VAR(k)) //loop index MOV(RAX, VAR(a)) //load address of a diff --git a/kernels/skx/CMakeLists.txt b/kernels/skx/CMakeLists.txt new file mode 100644 index 0000000000..bc8f1eaab3 --- /dev/null +++ b/kernels/skx/CMakeLists.txt @@ -0,0 +1,4 @@ +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## + +add_subdirectory(3) + diff --git a/kernels/zen/1/CMakeLists.txt b/kernels/zen/1/CMakeLists.txt index 434be490d5..dbdd1533e2 100644 --- a/kernels/zen/1/CMakeLists.txt +++ b/kernels/zen/1/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE diff --git a/kernels/zen/1/bli_amaxv_zen_int.c b/kernels/zen/1/bli_amaxv_zen_int.c index e72705340e..7f799fa628 100644 --- a/kernels/zen/1/bli_amaxv_zen_int.c +++ b/kernels/zen/1/bli_amaxv_zen_int.c @@ -4,8 +4,8 @@ An object-based framework for developing high-performance BLAS-like libraries. 
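In bli_dgemm_skx_asm_16x14.c the accumulator setup now clears YMM(4) once with VXORPD and initializes the remaining accumulators by copying it with VMOVAPD, rather than issuing one VXORPD per register; the end state is identical (every accumulator starts at zero). The hunk does not state why the move-based form was preferred, so the sketch below only illustrates the pattern, not the motivation, using AVX-512 intrinsics and a purely illustrative 4x7 accumulator shape:

#include <immintrin.h>

// Clear one value explicitly, then initialize all accumulators from it.
static void clear_accumulators( __m512d acc[4][7] )
{
    __m512d zero = _mm512_setzero_pd();   // the single explicit clear
    for ( int i = 0; i < 4; ++i )
        for ( int j = 0; j < 7; ++j )
            acc[ i ][ j ] = zero;         // copies of the cleared value
}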
- Copyright (C) 2016 - 2018 - 2019, Advanced Micro Devices, Inc. - Copyright (C) 2018, The University of Texas at Austin + Copyright (C) 2016 - 2022, Advanced Micro Devices, Inc. + Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,7 +36,6 @@ #include "immintrin.h" #include "blis.h" - /* Union data structure to access AVX registers One 256-bit AVX register holds 8 SP elements. */ typedef union @@ -266,7 +265,6 @@ void bli_samaxv_zen_int } // ----------------------------------------------------------------------------- - void bli_damaxv_zen_int ( dim_t n, @@ -423,101 +421,3 @@ void bli_damaxv_zen_int *i_max = i_max_l; AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) } - -// ----------------------------------------------------------------------------- - -#if 0 -#undef GENTFUNCR -#define GENTFUNCR( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - dim_t n, \ - ctype* x, inc_t incx, \ - dim_t* i_max, \ - cntx_t* cntx \ - ) \ -{ \ - ctype_r* minus_one = PASTEMAC(chr,m1); \ - dim_t* zero_i = PASTEMAC(i,0); \ -\ - ctype_r chi1_r; \ - ctype_r chi1_i; \ - ctype_r abs_chi1; \ - ctype_r abs_chi1_max; \ - dim_t i; \ -\ - /* Initialize the index of the maximum absolute value to zero. */ \ - PASTEMAC(i,copys)( zero_i, *i_max ); \ -\ - /* If the vector length is zero, return early. This directly emulates - the behavior of netlib BLAS's i?amax() routines. */ \ - if ( bli_zero_dim1( n ) ) return; \ -\ - /* Initialize the maximum absolute value search candidate with - -1, which is guaranteed to be less than all values we will - compute. */ \ - PASTEMAC(chr,copys)( *minus_one, abs_chi1_max ); \ -\ - if ( incx == 1 ) \ - { \ - for ( i = 0; i < n; ++i ) \ - { \ - /* Get the real and imaginary components of chi1. */ \ - PASTEMAC2(ch,chr,gets)( x[i], chi1_r, chi1_i ); \ -\ - /* Replace chi1_r and chi1_i with their absolute values. */ \ - PASTEMAC(chr,abval2s)( chi1_r, chi1_r ); \ - PASTEMAC(chr,abval2s)( chi1_i, chi1_i ); \ -\ - /* Add the real and imaginary absolute values together. */ \ - PASTEMAC(chr,set0s)( abs_chi1 ); \ - PASTEMAC(chr,adds)( chi1_r, abs_chi1 ); \ - PASTEMAC(chr,adds)( chi1_i, abs_chi1 ); \ -\ - /* If the absolute value of the current element exceeds that of - the previous largest, save it and its index. If NaN is - encountered, then treat it the same as if it were a valid - value that was smaller than any previously seen. This - behavior mimics that of LAPACK's ?lange(). */ \ - if ( abs_chi1_max < abs_chi1 || bli_isnan( abs_chi1 ) ) \ - { \ - abs_chi1_max = abs_chi1; \ - *i_max = i; \ - } \ - } \ - } \ - else \ - { \ - for ( i = 0; i < n; ++i ) \ - { \ - ctype* chi1 = x + (i )*incx; \ -\ - /* Get the real and imaginary components of chi1. */ \ - PASTEMAC2(ch,chr,gets)( *chi1, chi1_r, chi1_i ); \ -\ - /* Replace chi1_r and chi1_i with their absolute values. */ \ - PASTEMAC(chr,abval2s)( chi1_r, chi1_r ); \ - PASTEMAC(chr,abval2s)( chi1_i, chi1_i ); \ -\ - /* Add the real and imaginary absolute values together. */ \ - PASTEMAC(chr,set0s)( abs_chi1 ); \ - PASTEMAC(chr,adds)( chi1_r, abs_chi1 ); \ - PASTEMAC(chr,adds)( chi1_i, abs_chi1 ); \ -\ - /* If the absolute value of the current element exceeds that of - the previous largest, save it and its index. If NaN is - encountered, then treat it the same as if it were a valid - value that was smaller than any previously seen. This - behavior mimics that of LAPACK's ?lange(). 
*/ \ - if ( abs_chi1_max < abs_chi1 || bli_isnan( abs_chi1 ) ) \ - { \ - abs_chi1_max = abs_chi1; \ - *i_max = i; \ - } \ - } \ - } \ -} -GENTFUNCR( scomplex, float, c, s, amaxv_zen_int ) -GENTFUNCR( dcomplex, double, z, d, amaxv_zen_int ) -#endif diff --git a/kernels/zen/1/bli_norm2_zen_int.c b/kernels/zen/1/bli_norm2_zen_int.c index 0a0f92e36c..1971b79433 100644 --- a/kernels/zen/1/bli_norm2_zen_int.c +++ b/kernels/zen/1/bli_norm2_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,203 +34,922 @@ #include "immintrin.h" #include "blis.h" -#ifdef BLIS_ENABLE_FAST_MATH -/* Union data structure to access AVX registers - One 256-bit AVX register holds 8 SP elements. */ +// Union data structure to access AVX registers +// One 256-bit AVX register holds 8 SP elements. typedef union { __m256 v; - float f[8] __attribute__((aligned(64))); + float f[8] __attribute__( ( aligned( 64 ) ) ); } v8sf_t; -/* Union data structure to access AVX registers -* One 256-bit AVX register holds 4 DP elements. */ +// Union data structure to access AVX registers +// One 256-bit AVX register holds 4 DP elements. typedef union { __m256d v; - double d[4] __attribute__((aligned(64))); + double d[4] __attribute__( ( aligned( 64 ) ) ); } v4df_t; -// ----------------------------------------------------------------------------- +// Return a mask which indicates either: +// v <= t or v >= T +#define CMP256( v, t, T ) \ + _mm256_or_pd( _mm256_cmp_pd( v, t, _CMP_LE_OS ), _mm256_cmp_pd( v, T, _CMP_GE_OS ) ); -void bli_dnorm2fv_unb_var1 - ( +// Returns true if any of the values in the mask vector is true, +// and false, otherwise. +static inline bool bli_horizontal_or( __m256d a ) { return ! _mm256_testz_pd( a, a ); } + +// Optimized function that computes the Frobenius norm using AVX2 intrinsics. +void bli_dnorm2fv_unb_var1_avx2 + ( dim_t n, double* x, inc_t incx, double* norm, cntx_t* cntx - ) + ) { + AOCL_DTL_TRACE_ENTRY( AOCL_DTL_LEVEL_TRACE_3 ); + double sumsq = 0; - double rem_sumsq = 0; /*sum of squares accumulated for n_remainder<8 cases.*/ + dim_t i = 0; dim_t n_remainder = 0; - dim_t i; - /*memory pool declarations for packing vector X. - Initialize mem pool buffer to NULL and size to 0 - "buf" and "size" fields are assigned once memory - is allocated from the pool in bli_membrk_acquire_m(). - This will ensure bli_mem_is_alloc() will be passed on - an allocated memory if created or a NULL .*/ - mem_t mem_bufX = {0}; - rntm_t rntm; double *x_buf = x; - - /*early return if n<=0 or incx =0 */ - if((n <= 0) || (incx == 0)) - return; - - /*packing for non-unit strided Vector X*/ - if(incx != 1) + + // Early return if n<=0 or incx=0 + if ( ( n <= 0) || ( incx == 0 ) ) { - /* In order to get the buffer from pool via rntm access to memory broker - is needed.Following are initializations for rntm */ + return; + } + + // Memory pool declarations for packing vector X. + // Initialize mem pool buffer to NULL and size to 0. + // "buf" and "size" fields are assigned once memory + // is allocated from the pool in bli_membrk_acquire_m(). + // This will ensure bli_mem_is_alloc() will be passed on + // an allocated memory if created or a NULL. 
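The two helpers introduced at the top of the new norm2 code do the lane classification: CMP256( v, t, T ) builds a mask that is all-ones in every lane where v <= t or v >= T, and bli_horizontal_or uses vtestpd to report whether any lane of such a mask is set. Rewritten as plain inline functions (same intrinsics, hypothetical names):

#include <immintrin.h>
#include <stdbool.h>

// All-ones in the lanes where v <= t or v >= T (cf. the CMP256 macro).
static inline __m256d cmp_outside_range( __m256d v, __m256d t, __m256d T )
{
    return _mm256_or_pd( _mm256_cmp_pd( v, t, _CMP_LE_OS ),
                         _mm256_cmp_pd( v, T, _CMP_GE_OS ) );
}

// True if any lane of the mask is set (cf. bli_horizontal_or): vtestpd
// sets ZF only when every sign bit of (a & a) is zero.
static inline bool any_lane( __m256d a )
{
    return !_mm256_testz_pd( a, a );
}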
+ mem_t mem_bufX = {0}; + rntm_t rntm; + // Packing for non-unit strided vector x. + if ( incx != 1 ) + { + // In order to get the buffer from pool via rntm access to memory broker + //is needed. Following are initializations for rntm. bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); bli_membrk_rntm_set_membrk( &rntm ); - //calculate the size required for "n" double elements in vector X. - size_t buffer_size = n * sizeof(double); + // Calculate the size required for "n" double elements in vector x. + size_t buffer_size = n * sizeof( double ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_dnorm2fv_unb_var1(): get mem pool block\n" ); #endif - /*acquire a Buffer(n*size(double)) from the memory broker - and save the associated mem_t entry to mem_bufX.*/ - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BUFFER_FOR_B_PANEL, - &mem_bufX); - - /*Continue packing X if buffer memory is allocated*/ - if ((bli_mem_is_alloc( &mem_bufX ))) + // Acquire a Buffer(n*size(double)) from the memory broker + // and save the associated mem_t entry to mem_bufX. + bli_membrk_acquire_m + ( + &rntm, + buffer_size, + BLIS_BUFFER_FOR_B_PANEL, + &mem_bufX + ); + + // Continue packing X if buffer memory is allocated. + if ( ( bli_mem_is_alloc( &mem_bufX ) ) ) { - x_buf = bli_mem_buffer(&mem_bufX); - - /*pack X vector with non-unit stride to a temp buffer x_buf with unit stride*/ - for(dim_t x_index = 0 ; x_index < n ; x_index++) + x_buf = bli_mem_buffer( &mem_bufX ); + // Pack vector x with non-unit stride to a temp buffer x_buf with unit stride. + for ( dim_t x_index = 0; x_index < n; x_index++ ) { - *(x_buf + x_index) = *(x + (x_index * incx)) ; + if ( incx > 0 ) + { + *( x_buf + x_index ) = *( x + ( x_index * incx ) ); + } + else + { + *( x_buf + x_index ) = *( x + ( - ( n - x_index - 1 ) * incx ) ); + } } } } - v4df_t x0v, x1v, x2v, x3v, x4v, x5v, x6v, x7v; - /* Initialize rho vector accumulators to zero.*/ - v4df_t rho0v; rho0v.v = _mm256_setzero_pd(); - v4df_t rho1v; rho1v.v = _mm256_setzero_pd(); - v4df_t rho2v; rho2v.v = _mm256_setzero_pd(); - v4df_t rho3v; rho3v.v = _mm256_setzero_pd(); - v4df_t rho4v; rho4v.v = _mm256_setzero_pd(); - v4df_t rho5v; rho5v.v = _mm256_setzero_pd(); - v4df_t rho6v; rho6v.v = _mm256_setzero_pd(); - v4df_t rho7v; rho7v.v = _mm256_setzero_pd(); + double *xt = x_buf; + + // Compute the sum of squares on 3 accumulators to avoid overflow + // and underflow, depending on the vector element value. + // Accumulator for small values; using scaling to avoid underflow. + double sum_sml = 0; + // Accumulator for medium values; no scaling required. + double sum_med = 0; + // Accumulator for big values; using scaling to avoid overflow. + double sum_big = 0; - double *x0 = x_buf; + // Constants chosen to minimize roundoff, according to Blue's algorithm. + const double thres_sml = pow( ( double )FLT_RADIX, ceil( ( DBL_MIN_EXP - 1 ) * 0.5 ) ); + const double thres_big = pow( ( double )FLT_RADIX, floor( ( DBL_MAX_EXP - 52) * 0.5 ) ); + const double scale_sml = pow( ( double )FLT_RADIX, - floor( ( DBL_MIN_EXP - 53 ) * 0.5 ) ); + const double scale_big = pow( ( double )FLT_RADIX, - ceil( ( DBL_MAX_EXP - 52 ) * 0.5 ) ); - for(i = 0 ; i+31 < n ; i = i + 32) + double scale; + double abs_chi; + bool isbig = false; + + if ( n > 4 ) { + // Constants used for comparisons. 
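The thres_sml / thres_big / scale_sml / scale_big constants above implement Blue's algorithm: each |x_i| is routed to one of three accumulators, sum_med (squared as-is), sum_big (pre-scaled down by scale_big to avoid overflow), or sum_sml (pre-scaled up by scale_sml to avoid underflow), and small values are only tracked while no big value has been seen. A scalar sketch of that routing, mirroring the remainder loop further down (hypothetical helper name):

#include <math.h>
#include <float.h>
#include <stdbool.h>

// Route one |x_i| to the medium, big, or small accumulator.
static void accumulate_blue( double abs_chi,
                             double* sum_sml, double* sum_med, double* sum_big,
                             bool*   isbig )
{
    const double thres_sml = pow( (double)FLT_RADIX,  ceil ( ( DBL_MIN_EXP - 1  ) * 0.5 ) );
    const double thres_big = pow( (double)FLT_RADIX,  floor( ( DBL_MAX_EXP - 52 ) * 0.5 ) );
    const double scale_sml = pow( (double)FLT_RADIX, -floor( ( DBL_MIN_EXP - 53 ) * 0.5 ) );
    const double scale_big = pow( (double)FLT_RADIX, -ceil ( ( DBL_MAX_EXP - 52 ) * 0.5 ) );

    if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) )
    {
        // Most likely case: medium values, no scaling needed.
        *sum_med += abs_chi * abs_chi;
    }
    else if ( abs_chi > thres_big )
    {
        // Potential overflow: scale down before squaring.
        *sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big );
        *isbig = true;
    }
    else if ( !( *isbig ) )
    {
        // Potential underflow: scale up before squaring, but only while
        // no big value has been seen (big values dominate the norm).
        *sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml );
    }
}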
+ v4df_t temp, thres_sml_vec, thres_big_vec, zerov, ymm0, ymm1; + temp.v = _mm256_set1_pd( -0.0 ); + thres_sml_vec.v = _mm256_set1_pd( thres_sml ); + thres_big_vec.v = _mm256_set1_pd( thres_big ); + v4df_t x0v, x1v, mask_vec0, mask_vec1; + zerov.v = _mm256_setzero_pd(); + + // Partial sums used for scaling. + v4df_t sum_med_vec0, sum_big_vec0, sum_sml_vec0, sum_med_vec1, sum_big_vec1, sum_sml_vec1; + sum_med_vec0.v = _mm256_setzero_pd(); + sum_big_vec0.v = _mm256_setzero_pd(); + sum_sml_vec0.v = _mm256_setzero_pd(); + sum_med_vec1.v = _mm256_setzero_pd(); + sum_big_vec1.v = _mm256_setzero_pd(); + sum_sml_vec1.v = _mm256_setzero_pd(); + + for (; ( i + 8 ) <= n; i = i + 8) + { + x0v.v = _mm256_loadu_pd( xt ); + x1v.v = _mm256_loadu_pd( xt + 4 ); + + // Getting the abs of the vector elements. + x0v.v = _mm256_andnot_pd( temp.v, x0v.v ); + x1v.v = _mm256_andnot_pd( temp.v, x1v.v ); + + // Check if any of the values is a NaN and if so, return. + mask_vec0.v = _mm256_cmp_pd(x0v.v, x0v.v, _CMP_UNORD_Q); + mask_vec1.v = _mm256_cmp_pd(x1v.v, x1v.v, _CMP_UNORD_Q); + if ( bli_horizontal_or( mask_vec0.v ) ) + { + *norm = NAN; + return; + } + if ( bli_horizontal_or( mask_vec1.v ) ) + { + *norm = NAN; + return; + } + + // Mask vectors which indicate whether + // xi<=thres_sml or xi>=thres_big. + mask_vec0.v = CMP256( x0v.v, thres_sml_vec.v, thres_big_vec.v ); + mask_vec1.v = CMP256( x1v.v, thres_sml_vec.v, thres_big_vec.v ); - x0v.v = _mm256_loadu_pd( x0 ); - x1v.v = _mm256_loadu_pd( x0 + 4 ); - x2v.v = _mm256_loadu_pd( x0 + 8 ); - x3v.v = _mm256_loadu_pd( x0 + 12 ); - x4v.v = _mm256_loadu_pd( x0 + 16 ); - x5v.v = _mm256_loadu_pd( x0 + 20 ); - x6v.v = _mm256_loadu_pd( x0 + 24 ); - x7v.v = _mm256_loadu_pd( x0 + 28 ); + if ( !bli_horizontal_or( mask_vec0.v ) ) + { + // Scaling is not necessary; only medium values. + sum_med_vec0.v = _mm256_fmadd_pd( x0v.v, x0v.v, sum_med_vec0.v ); + } + else + { + // Mask vector which indicate whether xi > thres_big. + mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_big_vec.v, _CMP_GT_OQ ); + + if ( bli_horizontal_or( mask_vec0.v ) ) + { + isbig = true; + + // Fill sum_med vector without scaling. + ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + + // Fill sum_big vector using scaling. + temp.v = _mm256_set1_pd( scale_big ); + ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); + ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); + sum_big_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_big_vec0.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } + else + { + // Mask vector which indicates whether xi > thres_small. + mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_sml_vec.v, _CMP_LT_OQ ); + // Fill sum_med vector without scaling. + ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + + // Accumulate small values only if there have not been any big values so far. + if ( !isbig ) + { + // Fill sum_sml vector using scaling. + temp.v = _mm256_set1_pd( scale_sml ); + ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); + ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); + sum_sml_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_sml_vec0.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } + } + } + + if ( !bli_horizontal_or( mask_vec1.v ) ) + { + // Scaling is not necessary; only medium values. + sum_med_vec1.v = _mm256_fmadd_pd( x1v.v, x1v.v, sum_med_vec1.v ); + } + else + { + // Mask vector which indicate whether xi > thres_big. 
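The NaN check inside the vector loops relies on NaN being unordered with every value, including itself: _mm256_cmp_pd( x, x, _CMP_UNORD_Q ) returns an all-ones lane exactly where x is NaN, and the horizontal-or test then decides whether to return with *norm = NAN. In isolation:

#include <immintrin.h>
#include <stdbool.h>

// True if any lane of x is NaN.
static inline bool has_nan_lane( __m256d x )
{
    // NaN compares unordered with itself, so the mask flags only NaN lanes.
    __m256d unord = _mm256_cmp_pd( x, x, _CMP_UNORD_Q );
    return !_mm256_testz_pd( unord, unord );
}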
+ mask_vec1.v = _mm256_cmp_pd( x1v.v, thres_big_vec.v, _CMP_GT_OQ ); + + if ( bli_horizontal_or( mask_vec1.v ) ) + { + isbig = true; + + // Fill sum_med vector without scaling. + ymm1.v = _mm256_blendv_pd( x1v.v, zerov.v, mask_vec1.v ); + sum_med_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_med_vec1.v ); + + // Fill sum_big vector using scaling. + temp.v = _mm256_set1_pd( scale_big ); + ymm1.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec1.v ); + ymm1.v = _mm256_mul_pd( x1v.v, ymm1.v ); + sum_big_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_big_vec1.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } + else + { + // Mask vector which indicates whether xi > thres_small. + mask_vec1.v = _mm256_cmp_pd( x1v.v, thres_sml_vec.v, _CMP_LT_OQ ); + // Fill sum_med vector without scaling. + ymm1.v = _mm256_blendv_pd( x1v.v, zerov.v, mask_vec1.v ); + sum_med_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_med_vec1.v ); + + // Accumulate small values only if there have not been any big values so far. + if ( !isbig ) + { + // Fill sum_sml vector using scaling. + temp.v = _mm256_set1_pd( scale_sml ); + ymm1.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec1.v ); + ymm1.v = _mm256_mul_pd( x1v.v, ymm1.v ); + sum_sml_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_sml_vec1.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } + } + } + + xt += 8; + } + + for ( ; ( i + 4 ) <= n; i = i + 4 ) + { + x0v.v = _mm256_loadu_pd( xt ); + + // Getting the abs of the vector elements. + x0v.v = _mm256_andnot_pd( temp.v, x0v.v ); + + // Check if any of the values is a NaN and if so, return. + mask_vec0.v = _mm256_cmp_pd(x0v.v, x0v.v, _CMP_UNORD_Q); + if ( bli_horizontal_or( mask_vec0.v ) ) + { + *norm = NAN; + return; + } + + // Mask vectors which indicate whether + // xi<=thres_sml or xi>=thres_big. + mask_vec0.v = CMP256( x0v.v, thres_sml_vec.v, thres_big_vec.v ); + + if ( !bli_horizontal_or( mask_vec0.v ) ) + { + // Scaling is not necessary; only medium values. + sum_med_vec0.v = _mm256_fmadd_pd( x0v.v, x0v.v, sum_med_vec0.v ); + } + else + { + // Mask vector which indicate whether xi > thres_big. + mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_big_vec.v, _CMP_GT_OQ ); + + if ( bli_horizontal_or( mask_vec0.v ) ) + { + isbig = true; + + // Fill sum_med vector without scaling. + ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + + // Fill sum_big vector using scaling. + temp.v = _mm256_set1_pd( scale_big ); + ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); + ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); + sum_big_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_big_vec0.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } + else + { + // Mask vector which indicates whether xi > thres_small. + mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_sml_vec.v, _CMP_LT_OQ ); + // Fill sum_med vector without scaling. + ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + + // Accumulate small values only if there have not been any big values so far. + if ( !isbig ) + { + // Fill sum_sml vector using scaling. 
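Within each branch the per-lane routing is done with _mm256_blendv_pd: blending x against zero under the mask keeps only the lanes that belong in the medium sum, while blending zero against a broadcast scale builds a per-lane factor that is scale for the flagged lanes and 0 elsewhere, so each element contributes to exactly one accumulator. A condensed sketch of one such step (hypothetical helper name):

#include <immintrin.h>

// x: |values|; mask: all-ones in the lanes that need scaling.
static void route_lanes( __m256d x, __m256d mask, double scale,
                         __m256d* sum_med, __m256d* sum_scaled )
{
    __m256d zero = _mm256_setzero_pd();

    // Unflagged lanes are squared as-is into the medium accumulator;
    // flagged lanes contribute zero here.
    __m256d med = _mm256_blendv_pd( x, zero, mask );
    *sum_med = _mm256_fmadd_pd( med, med, *sum_med );

    // Flagged lanes are scaled before squaring; unflagged lanes get a
    // zero factor and add nothing to the scaled accumulator.
    __m256d s  = _mm256_blendv_pd( zero, _mm256_set1_pd( scale ), mask );
    __m256d xs = _mm256_mul_pd( x, s );
    *sum_scaled = _mm256_fmadd_pd( xs, xs, *sum_scaled );
}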
+ temp.v = _mm256_set1_pd( scale_sml ); + ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); + ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); + sum_sml_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_sml_vec0.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } + } + } + xt += 4; + } - rho0v.v = _mm256_fmadd_pd(x0v.v, x0v.v, rho0v.v); - rho1v.v = _mm256_fmadd_pd(x1v.v, x1v.v, rho1v.v); - rho2v.v = _mm256_fmadd_pd(x2v.v, x2v.v, rho2v.v); - rho3v.v = _mm256_fmadd_pd(x3v.v, x3v.v, rho3v.v); - rho4v.v = _mm256_fmadd_pd(x4v.v, x4v.v, rho4v.v); - rho5v.v = _mm256_fmadd_pd(x5v.v, x5v.v, rho5v.v); - rho6v.v = _mm256_fmadd_pd(x6v.v, x6v.v, rho6v.v); - rho7v.v = _mm256_fmadd_pd(x7v.v, x7v.v, rho7v.v); + sum_sml_vec0.v = _mm256_add_pd( sum_sml_vec0.v, sum_sml_vec1.v ); + sum_med_vec0.v = _mm256_add_pd( sum_med_vec0.v, sum_med_vec1.v ); + sum_big_vec0.v = _mm256_add_pd( sum_big_vec0.v, sum_big_vec1.v ); - x0 += 32; + sum_sml += sum_sml_vec0.v[0] + sum_sml_vec0.v[1] + + sum_sml_vec0.v[2] + sum_sml_vec0.v[3]; + sum_med += sum_med_vec0.v[0] + sum_med_vec0.v[1] + + sum_med_vec0.v[2] + sum_med_vec0.v[3]; + sum_big += sum_big_vec0.v[0] + sum_big_vec0.v[1] + + sum_big_vec0.v[2] + sum_big_vec0.v[3]; } n_remainder = n - i; + bool hasInf = false; + if ( ( n_remainder > 0 ) ) + { + // Put first the most likely to happen to avoid evaluations on if statements. + for (i = 0; i < n_remainder; i++) + { + abs_chi = bli_fabs( *xt ); + // If any of the elements is NaN, then return NaN as a result. + if ( bli_isnan( abs_chi ) ) + { + *norm = abs_chi; + return; + } + // Else, if any of the elements is an Inf, then return +Inf as a result. + if ( bli_isinf( abs_chi ) ) + { + *norm = abs_chi; + // Instead of returning immediately, use this flag + // to denote that there is an Inf element in the vector. + // That is used to avoid cases where there is a NaN which comes + // after an Inf. + hasInf = true; + } + // Most likely case: medium values, not over/under-flow. + if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) ) + { + sum_med += abs_chi * abs_chi; + } + // Case where there could be an overflow. Scaling is required. + else if ( abs_chi > thres_big ) + { + sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big ); + isbig = true; + } + // Case where there could be an underflow. Scaling is required. + else if ( ( !isbig ) && ( abs_chi < thres_sml ) ) + { + sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml ); + } + xt++; + } + } + + // Early return if there is an Inf. + if ( hasInf ) return; + + // Combine accumulators. + if ( isbig ) + { + // Combine sum_big and sum_med if sum_med > 0. + if ( sum_med > 0.0 ) + { + sum_big += ( sum_med * scale_big ) * scale_big; + } + scale = 1.0 / scale_big; + sumsq = sum_big; + } - if(n_remainder) + else if ( sum_sml > 0.0 ) { - if(n_remainder >= 16) + // Combine sum_med and sum_sml if sum_sml>0. 
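+		// Both partial sums are first mapped back to the unscaled domain,
+		// sqrt( sum_med ) and sqrt( sum_sml ) / scale_sml, and then combined
+		// hypot-style as ymax * sqrt( 1 + ( ymin / ymax )^2 ), so the larger
+		// term is factored out and the intermediate ratio is at most 1.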
+ if ( sum_med > 0.0 ) { - x0v.v = _mm256_loadu_pd( x0 ); - x1v.v = _mm256_loadu_pd( x0 + 4 ); - x2v.v = _mm256_loadu_pd( x0 + 8 ); - x3v.v = _mm256_loadu_pd( x0 + 12 ); - - rho0v.v = _mm256_fmadd_pd(x0v.v, x0v.v, rho0v.v); - rho1v.v = _mm256_fmadd_pd(x1v.v, x1v.v, rho1v.v); - rho2v.v = _mm256_fmadd_pd(x2v.v, x2v.v, rho2v.v); - rho3v.v = _mm256_fmadd_pd(x3v.v, x3v.v, rho3v.v); - - x0 += 16; - n_remainder -= 16; + sum_med = sqrt( sum_med ); + sum_sml = sqrt( sum_sml ) / scale_sml; + double ymin, ymax; + if ( sum_sml > sum_med ) + { + ymin = sum_med; + ymax = sum_sml; + } + else + { + ymin = sum_sml; + ymax = sum_med; + } + scale = 1.0; + sumsq = ymax * ymax * ( 1.0 + ( ymin / ymax ) * ( ymin / ymax ) ); } - if(n_remainder >= 8) + else { - x0v.v = _mm256_loadu_pd( x0 ); - x1v.v = _mm256_loadu_pd( x0 + 4 ); + scale = 1.0 / scale_sml; + sumsq = sum_sml; + } + } + else + { + // If all values are mid-range: + scale = 1.0; + sumsq = sum_med; + } + + *norm = scale * sqrt( sumsq ); + + if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) + { + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_dnorm2fv_unb_var1(): releasing mem pool block\n" ); + #endif + // Return the buffer to pool. + bli_membrk_release( &rntm , &mem_bufX ); + } + + AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - rho0v.v = _mm256_fmadd_pd(x0v.v, x0v.v, rho0v.v); - rho1v.v = _mm256_fmadd_pd(x1v.v, x1v.v, rho1v.v); + return; +} + +// Optimized function that computes the Frobenius norm using AVX2 intrinsics. +void bli_dznorm2fv_unb_var1_avx2 + ( + dim_t n, + dcomplex* x, inc_t incx, + double* norm, + cntx_t* cntx + ) +{ + AOCL_DTL_TRACE_ENTRY( AOCL_DTL_LEVEL_TRACE_3 ); - x0 += 8; - n_remainder -= 8; + double sumsq = 0; + dim_t i = 0; + dim_t n_remainder = 0; + dcomplex *x_buf = x; + + // Early return if n<=0 or incx=0 + if ( ( n <= 0) || ( incx == 0 ) ) + { + return; + } + + // Memory pool declarations for packing vector X. + // Initialize mem pool buffer to NULL and size to 0. + // "buf" and "size" fields are assigned once memory + // is allocated from the pool in bli_membrk_acquire_m(). + // This will ensure bli_mem_is_alloc() will be passed on + // an allocated memory if created or a NULL. + mem_t mem_bufX = {0}; + rntm_t rntm; + + // Packing for non-unit strided vector x. + if ( incx != 1 ) + { + // In order to get the buffer from pool via rntm access to memory broker + //is needed. Following are initializations for rntm. + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + // Calculate the size required for "n" dcomplex elements in vector x. + size_t buffer_size = n * sizeof( dcomplex ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_dznorm2fv_unb_var1(): get mem pool block\n" ); + #endif + + // Acquire a Buffer(n*size(dcomplex)) from the memory broker + // and save the associated mem_t entry to mem_bufX. + bli_membrk_acquire_m + ( + &rntm, + buffer_size, + BLIS_BUFFER_FOR_B_PANEL, + &mem_bufX + ); + + // Continue packing X if buffer memory is allocated. + if ( ( bli_mem_is_alloc( &mem_bufX ) ) ) + { + x_buf = bli_mem_buffer( &mem_bufX ); + // Pack vector x with non-unit stride to a temp buffer x_buf with unit stride. 
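+			// For a negative increment the offset -( n - x_index - 1 ) * incx is
+			// positive, so elements are gathered from the far end of the array;
+			// the element order does not affect the norm.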
+ for ( dim_t x_index = 0; x_index < n; x_index++ ) + { + if ( incx > 0 ) + { + *( x_buf + x_index ) = *( x + ( x_index * incx ) ); + } + else + { + *( x_buf + x_index ) = *( x + ( - ( n - x_index - 1 ) * incx ) ); + } + } } - if(n_remainder >= 4) + } + + dcomplex *xt = x_buf; + + // Compute the sum of squares on 3 accumulators to avoid overflow + // and underflow, depending on the vector element value. + // Accumulator for small values; using scaling to avoid underflow. + double sum_sml = 0; + // Accumulator for medium values; no scaling required. + double sum_med = 0; + // Accumulator for big values; using scaling to avoid overflow. + double sum_big = 0; + + // Constants chosen to minimize roundoff, according to Blue's algorithm. + const double thres_sml = pow( ( double )FLT_RADIX, ceil( ( DBL_MIN_EXP - 1 ) * 0.5 ) ); + const double thres_big = pow( ( double )FLT_RADIX, floor( ( DBL_MAX_EXP - 52) * 0.5 ) ); + const double scale_sml = pow( ( double )FLT_RADIX, - floor( ( DBL_MIN_EXP - 53 ) * 0.5 ) ); + const double scale_big = pow( ( double )FLT_RADIX, - ceil( ( DBL_MAX_EXP - 52 ) * 0.5 ) ); + + double scale; + double abs_chi; + bool isbig = false; + + if ( n > 2 ) + { + // Constants used for comparisons. + v4df_t temp, thres_sml_vec, thres_big_vec, zerov, ymm0, ymm1; + temp.v = _mm256_set1_pd( -0.0 ); + thres_sml_vec.v = _mm256_set1_pd( thres_sml ); + thres_big_vec.v = _mm256_set1_pd( thres_big ); + v4df_t x0v, x1v, mask_vec0, mask_vec1; + zerov.v = _mm256_setzero_pd(); + + // Partial sums used for scaling. + v4df_t sum_med_vec0, sum_big_vec0, sum_sml_vec0, sum_med_vec1, sum_big_vec1, sum_sml_vec1; + sum_med_vec0.v = _mm256_setzero_pd(); + sum_big_vec0.v = _mm256_setzero_pd(); + sum_sml_vec0.v = _mm256_setzero_pd(); + sum_med_vec1.v = _mm256_setzero_pd(); + sum_big_vec1.v = _mm256_setzero_pd(); + sum_sml_vec1.v = _mm256_setzero_pd(); + + for (; ( i + 4 ) <= n; i = i + 4) { - x0v.v = _mm256_loadu_pd( x0 ); + x0v.v = _mm256_loadu_pd( (double*) xt ); + x1v.v = _mm256_loadu_pd( (double*) (xt + 2) ); + + // Getting the abs of the vector elements. + x0v.v = _mm256_andnot_pd( temp.v, x0v.v ); + x1v.v = _mm256_andnot_pd( temp.v, x1v.v ); - rho0v.v = _mm256_fmadd_pd(x0v.v, x0v.v, rho0v.v); + // Check if any of the values is a NaN and if so, return. + mask_vec0.v = _mm256_cmp_pd(x0v.v, x0v.v, _CMP_UNORD_Q); + mask_vec1.v = _mm256_cmp_pd(x1v.v, x1v.v, _CMP_UNORD_Q); + if ( bli_horizontal_or( mask_vec0.v ) ) + { + *norm = NAN; + return; + } + if ( bli_horizontal_or( mask_vec1.v ) ) + { + *norm = NAN; + return; + } + + // Mask vectors which indicate whether + // xi<=thres_sml or xi>=thres_big. + mask_vec0.v = CMP256( x0v.v, thres_sml_vec.v, thres_big_vec.v ); + mask_vec1.v = CMP256( x1v.v, thres_sml_vec.v, thres_big_vec.v ); + + if ( !bli_horizontal_or( mask_vec0.v ) ) + { + // Scaling is not necessary; only medium values. + sum_med_vec0.v = _mm256_fmadd_pd( x0v.v, x0v.v, sum_med_vec0.v ); + } + else + { + // Mask vector which indicate whether xi > thres_big. + mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_big_vec.v, _CMP_GT_OQ ); + + if ( bli_horizontal_or( mask_vec0.v ) ) + { + isbig = true; + + // Fill sum_med vector without scaling. + ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + + // Fill sum_big vector using scaling. 
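+					// Only the lanes flagged as big receive the scale_big factor;
+					// the other lanes are blended with zero so they do not
+					// contribute to sum_big here.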
+ temp.v = _mm256_set1_pd( scale_big ); + ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); + ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); + sum_big_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_big_vec0.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } + else + { + // Mask vector which indicates whether xi > thres_small. + mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_sml_vec.v, _CMP_LT_OQ ); + // Fill sum_med vector without scaling. + ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + + // Accumulate small values only if there have not been any big values so far. + if ( !isbig ) + { + // Fill sum_sml vector using scaling. + temp.v = _mm256_set1_pd( scale_sml ); + ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); + ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); + sum_sml_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_sml_vec0.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } + } + } + + if ( !bli_horizontal_or( mask_vec1.v ) ) + { + // Scaling is not necessary; only medium values. + sum_med_vec1.v = _mm256_fmadd_pd( x1v.v, x1v.v, sum_med_vec1.v ); + } + else + { + // Mask vector which indicate whether xi > thres_big. + mask_vec1.v = _mm256_cmp_pd( x1v.v, thres_big_vec.v, _CMP_GT_OQ ); + + if ( bli_horizontal_or( mask_vec1.v ) ) + { + isbig = true; + + // Fill sum_med vector without scaling. + ymm1.v = _mm256_blendv_pd( x1v.v, zerov.v, mask_vec1.v ); + sum_med_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_med_vec1.v ); + + // Fill sum_big vector using scaling. + temp.v = _mm256_set1_pd( scale_big ); + ymm1.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec1.v ); + ymm1.v = _mm256_mul_pd( x1v.v, ymm1.v ); + sum_big_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_big_vec1.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } + else + { + // Mask vector which indicates whether xi > thres_small. + mask_vec1.v = _mm256_cmp_pd( x1v.v, thres_sml_vec.v, _CMP_LT_OQ ); + // Fill sum_med vector without scaling. + ymm1.v = _mm256_blendv_pd( x1v.v, zerov.v, mask_vec1.v ); + sum_med_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_med_vec1.v ); + + // Accumulate small values only if there have not been any big values so far. + if ( !isbig ) + { + // Fill sum_sml vector using scaling. + temp.v = _mm256_set1_pd( scale_sml ); + ymm1.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec1.v ); + ymm1.v = _mm256_mul_pd( x1v.v, ymm1.v ); + sum_sml_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_sml_vec1.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } + } + } - x0 += 4; - n_remainder -= 4; + xt += 4; } - if(n_remainder) + + for ( ; ( i + 2 ) <= n; i = i + 2 ) { - for(i=0; i< n_remainder ;i++) + x0v.v = _mm256_loadu_pd( (double*) xt ); + + // Getting the abs of the vector elements. + x0v.v = _mm256_andnot_pd( temp.v, x0v.v ); + + // Check if any of the values is a NaN and if so, return. + mask_vec0.v = _mm256_cmp_pd(x0v.v, x0v.v, _CMP_UNORD_Q); + if ( bli_horizontal_or( mask_vec0.v ) ) + { + *norm = NAN; + return; + } + + // Mask vectors which indicate whether + // xi<=thres_sml or xi>=thres_big. + mask_vec0.v = CMP256( x0v.v, thres_sml_vec.v, thres_big_vec.v ); + + if ( !bli_horizontal_or( mask_vec0.v ) ) + { + // Scaling is not necessary; only medium values. + sum_med_vec0.v = _mm256_fmadd_pd( x0v.v, x0v.v, sum_med_vec0.v ); + } + else { - double x_temp = *x0; - rem_sumsq += x_temp * x_temp ; - x0 += 1; + // Mask vector which indicate whether xi > thres_big. 
+ mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_big_vec.v, _CMP_GT_OQ ); + + if ( bli_horizontal_or( mask_vec0.v ) ) + { + isbig = true; + + // Fill sum_med vector without scaling. + ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + + // Fill sum_big vector using scaling. + temp.v = _mm256_set1_pd( scale_big ); + ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); + ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); + sum_big_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_big_vec0.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } + else + { + // Mask vector which indicates whether xi > thres_small. + mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_sml_vec.v, _CMP_LT_OQ ); + // Fill sum_med vector without scaling. + ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + + // Accumulate small values only if there have not been any big values so far. + if ( !isbig ) + { + // Fill sum_sml vector using scaling. + temp.v = _mm256_set1_pd( scale_sml ); + ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); + ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); + sum_sml_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_sml_vec0.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } + } } + xt += 2; } + + sum_sml_vec0.v = _mm256_add_pd( sum_sml_vec0.v, sum_sml_vec1.v ); + sum_med_vec0.v = _mm256_add_pd( sum_med_vec0.v, sum_med_vec1.v ); + sum_big_vec0.v = _mm256_add_pd( sum_big_vec0.v, sum_big_vec1.v ); + + sum_sml += sum_sml_vec0.v[0] + sum_sml_vec0.v[1] + + sum_sml_vec0.v[2] + sum_sml_vec0.v[3]; + sum_med += sum_med_vec0.v[0] + sum_med_vec0.v[1] + + sum_med_vec0.v[2] + sum_med_vec0.v[3]; + sum_big += sum_big_vec0.v[0] + sum_big_vec0.v[1] + + sum_big_vec0.v[2] + sum_big_vec0.v[3]; } - /*add all the dot product of x*x into one vector .*/ - rho0v.v = _mm256_add_pd ( rho0v.v, rho1v.v ); - rho1v.v = _mm256_add_pd ( rho2v.v, rho3v.v ); - rho2v.v = _mm256_add_pd ( rho4v.v, rho5v.v ); - rho3v.v = _mm256_add_pd ( rho6v.v, rho7v.v ); + n_remainder = n - i; + bool hasInf = false; + if ( ( n_remainder > 0 ) ) + { + // Put first the most likely to happen to avoid evaluations on if statements. + for (i = 0; i < n_remainder; i++) + { + // Get real and imaginary component of the vector element. + double chi_r, chi_i; + bli_zdgets(*xt, chi_r, chi_i); + + // Start with accumulating the real component of the vector element. + abs_chi = bli_fabs( chi_r ); + // If any of the elements is NaN, then return NaN as a result. + if ( bli_isnan( abs_chi ) ) + { + *norm = abs_chi; + return; + } + // Else, if any of the elements is an Inf, then return +Inf as a result. + if ( bli_isinf( abs_chi ) ) + { + *norm = abs_chi; + // Instead of returning immediately, use this flag + // to denote that there is an Inf element in the vector. + // That is used to avoid cases where there is a NaN which comes + // after an Inf. + hasInf = true; + } + // Most likely case: medium values, not over/under-flow. + if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) ) + { + sum_med += abs_chi * abs_chi; + } + // Case where there could be an overflow. Scaling is required. + else if ( abs_chi > thres_big ) + { + sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big ); + isbig = true; + } + // Case where there could be an underflow. Scaling is required. 
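+			// Small values are tracked only while no big value has been seen;
+			// once isbig is set, sum_big dominates and sum_sml is ignored in
+			// the final combination.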
+ else if ( ( !isbig ) && ( abs_chi < thres_sml ) ) + { + sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml ); + } - rho4v.v = _mm256_add_pd ( rho0v.v, rho1v.v ); - rho5v.v = _mm256_add_pd ( rho2v.v, rho3v.v ); + // Accumulate the imaginary component of the vector element. + abs_chi = bli_fabs( chi_i ); + // If any of the elements is NaN, then return NaN as a result. + if ( bli_isnan( abs_chi ) ) + { + *norm = abs_chi; + return; + } + // Else, if any of the elements is an Inf, then return +Inf as a result. + if ( bli_isinf( abs_chi ) ) + { + *norm = abs_chi; + // Instead of returning immediately, use this flag + // to denote that there is an Inf element in the vector. + // That is used to avoid cases where there is a NaN which comes + // after an Inf. + hasInf = true; + } + // Most likely case: medium values, not over/under-flow. + if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) ) + { + sum_med += abs_chi * abs_chi; + } + // Case where there could be an overflow. Scaling is required. + else if ( abs_chi > thres_big ) + { + sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big ); + isbig = true; + } + // Case where there could be an underflow. Scaling is required. + else if ( ( !isbig ) && ( abs_chi < thres_sml ) ) + { + sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml ); + } - rho6v.v = _mm256_add_pd ( rho4v.v, rho5v.v ); + xt++; + } + } - rho7v.v = _mm256_hadd_pd( rho6v.v, rho6v.v ); + // Early return if there is an Inf. + if ( hasInf ) return; - /*rem_sumsq will have sum of squares of n_remainder < 4 cases . - Accumulate all the sum of squares to sumsq*/ - sumsq = rem_sumsq + rho7v.d[0] + rho7v.d[2]; + // Combine accumulators. + if ( isbig ) + { + // Combine sum_big and sum_med if sum_med > 0. + if ( sum_med > 0.0 ) + { + sum_big += ( sum_med * scale_big ) * scale_big; + } + scale = 1.0 / scale_big; + sumsq = sum_big; + } - PASTEMAC(d,sqrt2s)( sumsq, *norm ); + else if ( sum_sml > 0.0 ) + { + // Combine sum_med and sum_sml if sum_sml>0. + if ( sum_med > 0.0 ) + { + sum_med = sqrt( sum_med ); + sum_sml = sqrt( sum_sml ) / scale_sml; + double ymin, ymax; + if ( sum_sml > sum_med ) + { + ymin = sum_med; + ymax = sum_sml; + } + else + { + ymin = sum_sml; + ymax = sum_med; + } + scale = 1.0; + sumsq = ymax * ymax * ( 1.0 + ( ymin / ymax ) * ( ymin / ymax ) ); + } + else + { + scale = 1.0 / scale_sml; + sumsq = sum_sml; + } + } + else + { + // If all values are mid-range: + scale = 1.0; + sumsq = sum_med; + } + + *norm = scale * sqrt( sumsq ); - if ((incx != 1) && bli_mem_is_alloc( &mem_bufX )) + if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) { #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dnorm2fv_unb_var1(): releasing mem pool block\n" ); + printf( "bli_dznorm2fv_unb_var1(): releasing mem pool block\n" ); #endif - /* Return the buffer to pool*/ - bli_membrk_release(&rntm , &mem_bufX); + // Return the buffer to pool. 
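+		// ( Reached only when x had non-unit stride and a pack buffer was
+		//   actually acquired from the pool. )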
+ bli_membrk_release( &rntm , &mem_bufX ); } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return ; + + AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); + + return; } -#endif diff --git a/kernels/zen/2/CMakeLists.txt b/kernels/zen/2/CMakeLists.txt index d4ad0143ed..85ad4bfd5a 100644 --- a/kernels/zen/2/CMakeLists.txt +++ b/kernels/zen/2/CMakeLists.txt @@ -4,8 +4,17 @@ target_sources("${PROJECT_NAME}" PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_zen_ref.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_zen_int_4.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_zen_int_4.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_zen_int_4.c ) - - +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR +${TARGET_ARCH} STREQUAL zen2 OR +${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL zen4 OR +${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_zen_int_amd.c + ) +endif() \ No newline at end of file diff --git a/kernels/zen/2/bli_gemv_zen_int_4.c b/kernels/zen/2/bli_gemv_zen_int_4.c index 74904605ee..a4bdfb4499 100644 --- a/kernels/zen/2/bli_gemv_zen_int_4.c +++ b/kernels/zen/2/bli_gemv_zen_int_4.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/2/bli_her_zen_int_amd.c b/kernels/zen/2/bli_her_zen_int_amd.c new file mode 100644 index 0000000000..ee259b7e3e --- /dev/null +++ b/kernels/zen/2/bli_her_zen_int_amd.c @@ -0,0 +1,1128 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "immintrin.h" +#include "blis.h" + +/** + * Optimized implementation of ZHER for lower triangular row stored & + * upper triangular column stored matrix. + * This kernel performs: + * A := A + conj?(alpha) * conj?(x) * conj?(x)^H + * where, + * A is an m x m hermitian matrix stored in upper/lower triangular + * x is a vector of length m + * alpha is a scalar + */ +void bli_zher_zen_int_var1 +( + uplo_t uplo, + conj_t conjx, + conj_t conjh, + dim_t m, + dcomplex* restrict alpha, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict c, inc_t rs_c, inc_t cs_c, + cntx_t* restrict cntx +) +{ + double xcR, xcI; + double xhermcR, xhermcI; + double alphaR; + double interR, interI; + + dcomplex* xc; + dcomplex* xhermc; + dcomplex* cc; + + __m256d alphaRv; + __m256d ymm0, ymm1, ymm4, ymm5; + __m256d ymm6, ymm7, ymm8, ymm9, ymm10, ymm11; + __m256d ymm0_shuf, ymm1_shuf; + __m256d conj_mulv; + + dim_t conj_multiplier; + + inc_t rs_ct, cs_ct; + dim_t i = 0; + dim_t j = 0; + + alphaR = alpha->real; + + // The algorithm is expressed in terms of lower triangular case; + // the upper triangular case is supported by swapping the row and column + // strides of A & toggling the conj parameter. + if ( bli_is_lower( uplo ) ) + { + rs_ct = rs_c; + cs_ct = cs_c; + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_ct = cs_c; + cs_ct = rs_c; + conjx = bli_apply_conj( conjh, conjx ); + } + + // Enabling conj_multiplier for scalar multiplication based on conjx + if ( !bli_is_conj(conjx) ) conj_multiplier = 1; + else conj_multiplier = -1; + + // Broadcasting real values of alpha based on conjx + // alphaRv = aR aR aR aR + if ( bli_is_conj( conjx ) ) alphaRv = _mm256_broadcast_sd( &alphaR ); + else alphaRv = _mm256_set_pd( -alphaR, alphaR, -alphaR, alphaR ); + + conj_mulv = _mm256_set_pd( conj_multiplier, -1 * conj_multiplier, conj_multiplier, -1 * conj_multiplier ); + + /********* DIAGONAL ELEMENTS *********/ + // Solving for the diagonal elements using a scalar loop + for ( i = 0; i < m; i++ ) + { + xc = x + i*incx; + xcR = xc->real; + xcI = xc->imag; + xhermcR = xc->real; + xhermcI = xc->imag; + + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR + xcI * xhermcI; + + cc = c + (i)*rs_ct + (i)*cs_ct; + cc->real += interR; + cc->imag = 0; + } + + // Vectorized loop + for ( i = 0; ( i + 3 ) < m; i += 4 ) + { + // Loading elements of x to ymm0-1 for computing xherm vector + // ymm0 = x0R x0I x1R x1I + // ymm1 = x2R x2I x3R x3I + ymm0 = _mm256_loadu_pd( (double*)(x + i*incx) ); + ymm1 = _mm256_loadu_pd( (double*)(x + (i + 2)*incx) ); + + // Scaling xherm vector with alpha + // alphaRv = aR aR aR aR + // ymm0 = x0R -x0I x1R -x1I + // ymm1 = x2R -x2I x3R -x3I + // ymm0 * alphaRv = aR.x0R -aR.x0I aR.x1R -aR.x1I + // ymm1 * alphaRv = aR.x2R -aR.x2I aR.x3R -aR.x3I + ymm0 = _mm256_mul_pd( ymm0, alphaRv ); + ymm1 = _mm256_mul_pd( ymm1, alphaRv ); + + // Shuffling xherm vector for multiplication with x vector + // ymm0_shuf = -x0I x0R -x1I x1R + // ymm1_shuf = -x2I x2R -x3I x3R + ymm0_shuf = _mm256_permute_pd( ymm0, 5 ); + ymm1_shuf = _mm256_permute_pd( ymm1, 5 ); + + /********* TRIANGULAR BLOCK *********/ + // Solving the corner elements of the triangular block + // using scalar multiplication + xc = x + (i + 1)*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + (i)*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * 
xhermcR; + + cc = c + (i + 1)*rs_ct + (i + 0)*cs_ct; + cc->real += interR; + cc->imag += interI; + + xc = x + (i + 3)*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + (i + 2)*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + + cc = c + (i + 3)*rs_ct + (i + 2)*cs_ct; + cc->real += interR; + cc->imag += interI; + + // Solving the 2x2 square tile inside the triangular block + // using intrinsics + // Broadcasting elements from x to ymm4-5 + // ymm4 = x2R x2I x2R x2I + // ymm5 = x3R x3I x3R x3I + ymm4 = _mm256_broadcast_pd( (__m128d const*)( x + (i + 2)*incx ) ); + ymm5 = _mm256_broadcast_pd( (__m128d const*)( x + (i + 3)*incx ) ); + + // Loading a tile from matrix + // ymm10 = c20R c20I c21R c21I + // ymm11 = c30R c30I c31R c31I + ymm10 = _mm256_loadu_pd( (double*)( c + (i + 2)*rs_ct + (i)*cs_ct ) ); + ymm11 = _mm256_loadu_pd( (double*)( c + (i + 3)*rs_ct + (i)*cs_ct ) ); + + // Separating the real & imaginary parts of x into ymm4-7 + // ymm6 -> imag of ymm4 + // ymm4 -> real of ymm4 + ymm6 = _mm256_permute_pd( ymm4, 15 ); + ymm4 = _mm256_permute_pd( ymm4, 0 ); + ymm7 = _mm256_permute_pd( ymm5, 15 ); + ymm5 = _mm256_permute_pd( ymm5, 0 ); + + // Applying conjugate to elements of x vector + ymm6 = _mm256_mul_pd( ymm6, conj_mulv ); + ymm7 = _mm256_mul_pd( ymm7, conj_mulv ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + ymm9 = _mm256_mul_pd( ymm5, ymm0 ); + ymm9 = _mm256_fmadd_pd( ymm7, ymm0_shuf, ymm9 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd( (double*)( c + (i + 2)*rs_ct + (i)*cs_ct ), ymm10 ); + _mm256_storeu_pd( (double*)( c + (i + 3)*rs_ct + (i)*cs_ct ), ymm11 ); + + /********* SQUARE BLOCK *********/ + // Solving a 4x4 square block of matrix using intrinsics + for ( j = (i + 4); (j + 3) < m; j += 4) + { + // Broadcasting elements from x to ymm4-5 + ymm4 = _mm256_broadcast_pd( (__m128d const*)( x + (j )*incx ) ); + ymm5 = _mm256_broadcast_pd( (__m128d const*)( x + (j + 1)*incx ) ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd( (double*)( c + j*rs_ct + (i )*cs_ct ) ); + ymm11 = _mm256_loadu_pd( (double*)( c + j*rs_ct + (i + 2)*cs_ct ) ); + + // Separating the real & imaginary parts of x into ymm4-7 + // ymm6 -> imag of ymm4 + // ymm4 -> real of ymm4 + ymm6 = _mm256_permute_pd( ymm4, 15 ); + ymm4 = _mm256_permute_pd( ymm4, 0 ); + ymm7 = _mm256_permute_pd( ymm5, 15 ); + ymm5 = _mm256_permute_pd( ymm5, 0 ); + + // Applying conjugate to elements of x vector + ymm6 = _mm256_mul_pd( ymm6, conj_mulv ); + ymm7 = _mm256_mul_pd( ymm7, conj_mulv ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + ymm9 = _mm256_mul_pd( ymm4, ymm1 ); + ymm9 = _mm256_fmadd_pd( ymm6, ymm1_shuf, ymm9 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j)*rs_ct + (i)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j)*rs_ct + (i + 2)*cs_ct ), + ymm11 + ); + + // Loading a tile from matrix + 
ymm10 = _mm256_loadu_pd + ( + (double*)( c + (j + 1)*rs_ct + (i)*cs_ct ) + ); + ymm11 = _mm256_loadu_pd + ( + (double*)( c + (j + 1)*rs_ct + (i + 2)*cs_ct ) + ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm5, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm7, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + ymm9 = _mm256_mul_pd( ymm5, ymm1 ); + ymm9 = _mm256_fmadd_pd( ymm7, ymm1_shuf, ymm9 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j + 1)*rs_ct + (i)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j + 1)*rs_ct + (i + 2)*cs_ct ), + ymm11 + ); + + // Broadcasting elements from x to ymm4-5 + ymm4 = _mm256_broadcast_pd( (__m128d const*)( x + (j + 2)*incx ) ); + ymm5 = _mm256_broadcast_pd( (__m128d const*)( x + (j + 3)*incx ) ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i)*cs_ct ) + ); + ymm11 = _mm256_loadu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i + 2)*cs_ct ) + ); + + // Separating the real & imaginary parts of x into ymm4-7 + // ymm6 -> imag of ymm4 + // ymm4 -> real of ymm4 + ymm6 = _mm256_permute_pd( ymm4, 15 ); + ymm4 = _mm256_permute_pd( ymm4, 0 ); + ymm7 = _mm256_permute_pd( ymm5, 15 ); + ymm5 = _mm256_permute_pd( ymm5, 0 ); + + // Applying conjugate to elements of x vector + ymm6 = _mm256_mul_pd( ymm6, conj_mulv ); + ymm7 = _mm256_mul_pd( ymm7, conj_mulv ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + ymm9 = _mm256_mul_pd( ymm4, ymm1 ); + ymm9 = _mm256_fmadd_pd( ymm6, ymm1_shuf, ymm9 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i + 2)*cs_ct ), + ymm11 + ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd + ( + (double*)( c + (j + 3)*rs_ct + (i)*cs_ct ) + ); + ymm11 = _mm256_loadu_pd + ( + (double*)( c + (j + 3)*rs_ct + (i + 2)*cs_ct ) + ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm5, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm7, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + ymm9 = _mm256_mul_pd( ymm5, ymm1 ); + ymm9 = _mm256_fmadd_pd( ymm7, ymm1_shuf, ymm9 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j + 3)*rs_ct + (i)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j + 3)*rs_ct + (i + 2)*cs_ct ), + ymm11 + ); + } + + // Solving a 2x2 square block of matrix using intrinsics + for ( ; (j + 1) < m; j += 2) + { + // Broadcasting elements from x to ymm4-5 + ymm4 = _mm256_broadcast_pd( (__m128d const*)( x + (j)*incx ) ); + ymm5 = _mm256_broadcast_pd( (__m128d const*)( x + (j + 1)*incx ) ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd( (double*)( c + j*rs_ct + (i)*cs_ct ) ); + ymm11 = _mm256_loadu_pd( (double*)( c + j*rs_ct + (i + 2)*cs_ct ) ); + + // Separating the real & imaginary parts of x into ymm4-7 + // ymm6 -> imag of ymm4 + // ymm4 -> real of ymm4 + ymm6 = _mm256_permute_pd( ymm4, 15 ); + ymm4 = _mm256_permute_pd( ymm4, 0 ); + ymm7 = 
_mm256_permute_pd( ymm5, 15 ); + ymm5 = _mm256_permute_pd( ymm5, 0 ); + + // Applying conjugate to elements of x vector + ymm6 = _mm256_mul_pd( ymm6, conj_mulv ); + ymm7 = _mm256_mul_pd( ymm7, conj_mulv ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + ymm9 = _mm256_mul_pd( ymm4, ymm1 ); + ymm9 = _mm256_fmadd_pd( ymm6, ymm1_shuf, ymm9 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j)*rs_ct + (i)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j)*rs_ct + (i + 2)*cs_ct ), + ymm11 + ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd + ( + (double*)( c + (j + 1)*rs_ct + (i)*cs_ct ) + ); + ymm11 = _mm256_loadu_pd + ( + (double*)( c + (j + 1)*rs_ct + (i + 2)*cs_ct ) + ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm5, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm7, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + ymm9 = _mm256_mul_pd( ymm5, ymm1 ); + ymm9 = _mm256_fmadd_pd( ymm7, ymm1_shuf, ymm9 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j + 1)*rs_ct + (i)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j + 1)*rs_ct + (i + 2)*cs_ct ), + ymm11 + ); + } + + for ( ; j < m; j++ ) + { + // Broadcasting elements from x to ymm4-5 + ymm4 = _mm256_broadcast_pd( (__m128d const*)( x + (j)*incx ) ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd( (double*)( c + j*rs_ct + (i)*cs_ct ) ); + ymm11 = _mm256_loadu_pd( (double*)( c + j*rs_ct + (i + 2)*cs_ct ) ); + + // Separating the real & imaginary parts of x into ymm4-7 + // ymm6 -> imag of ymm4 + // ymm4 -> real of ymm4 + ymm6 = _mm256_permute_pd( ymm4, 15 ); + ymm4 = _mm256_permute_pd( ymm4, 0 ); + ymm7 = _mm256_permute_pd( ymm5, 15 ); + ymm5 = _mm256_permute_pd( ymm5, 0 ); + + // Applying conjugate to elements of x vector + ymm6 = _mm256_mul_pd( ymm6, conj_mulv ); + ymm7 = _mm256_mul_pd( ymm7, conj_mulv ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + ymm9 = _mm256_mul_pd( ymm4, ymm1 ); + ymm9 = _mm256_fmadd_pd( ymm6, ymm1_shuf, ymm9 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j)*rs_ct + (i)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j)*rs_ct + (i + 2)*cs_ct ), + ymm11 + ); + } + } + + // Solving the remaining blocks of matrix + for ( ; ( i + 1 ) < m; i += 2 ) + { + // Solving the corner elements + xc = x + (i + 1)*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + i*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + + cc = c + (i + 1)*rs_ct + i*cs_ct; + cc->real += interR; + cc->imag += interI; + + // Loading elements of x to ymm0 for computing xherm vector + ymm0 = _mm256_loadu_pd( (double*)( x + i*incx ) ); + + // Scaling xherm vector with alpha + ymm0 = _mm256_mul_pd( ymm0, 
alphaRv ); + + // Shuffling xherm vector for multiplication with x vector + ymm0_shuf = _mm256_permute_pd( ymm0, 5 ); + + /********* SQUARE BLOCK *********/ + // Solving a 2x2 square block of matrix using intrinsics + for ( j = ( i + 2 ); j < m; j++ ) + { + // Broadcasting elements from x to ymm4 + ymm4 = _mm256_broadcast_pd( (__m128d const*)( x + (j)*incx ) ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd( (double*)( c + (j)*rs_ct + (i)*cs_ct ) ); + + // Separating the real & imaginary parts of x into ymm4-7 + // ymm6 -> imag of ymm4 + // ymm4 -> real of ymm4 + ymm6 = _mm256_permute_pd( ymm4, 15 ); + ymm4 = _mm256_permute_pd( ymm4, 0 ); + + // Applying conjugate to elements of x vector + ymm6 = _mm256_mul_pd( ymm6, conj_mulv ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + // Storing back the results to the matrix + _mm256_storeu_pd( (double*)( c + (j)*rs_ct + (i)*cs_ct ), ymm10 ); + } + } +} + +/** + * Optimized implementation of ZHER for lower triangular column stored & + * upper triangular row stored matrix. + * This kernel performs: + * A := A + conj?(alpha) * conj?(x) * conj?(x)^H + * where, + * A is an m x m hermitian matrix stored in upper/lower triangular + * x is a vector of length m + * alpha is a scalar + */ +void bli_zher_zen_int_var2 +( + uplo_t uplo, + conj_t conjx, + conj_t conjh, + dim_t m, + dcomplex* alpha, + dcomplex* x, inc_t incx, + dcomplex* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx +) +{ + double xcR, xcI; + double xhermcR, xhermcI; + double alphaR; + double interR, interI; + + dcomplex* xc; + dcomplex* xhermc; + dcomplex* cc; + + __m256d alphaRv; + __m256d ymm0, ymm1, ymm2, ymm3, ymm4, ymm5; + __m256d ymm6, ymm7, ymm8, ymm9, ymm10, ymm11; + __m256d ymm0_shuf, ymm1_shuf, ymm2_shuf, ymm3_shuf; + + dim_t conj_multiplier; + + inc_t rs_ct, cs_ct; + dim_t i = 0; + dim_t j = 0; + + alphaR = alpha->real; + + // The algorithm is expressed in terms of lower triangular case; + // the upper triangular case is supported by swapping the row and column + // strides of A & toggling the conj parameter. 
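+	// ( Same convention as bli_zher_zen_int_var1 above: the indexing below is
+	//   written for the lower-triangular case, and rs_ct/cs_ct absorb the
+	//   difference for upper storage. )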
+ if ( bli_is_lower( uplo ) ) + { + rs_ct = rs_c; + cs_ct = cs_c; + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_ct = cs_c; + cs_ct = rs_c; + conjx = bli_apply_conj( conjh, conjx ); + } + + // Enabling conj_multiplier for scalar multiplication based on conjx + if ( !bli_is_conj(conjx) ) conj_multiplier = 1; + else conj_multiplier = -1; + + // Broadcasting real values of alpha based on conjx + // alphaRv = aR aR aR aR + if ( bli_is_conj( conjx ) ) alphaRv = _mm256_broadcast_sd( &alphaR ); + else alphaRv = _mm256_set_pd( -alphaR, alphaR, -alphaR, alphaR ); + + __m256d conj_mulv = _mm256_set_pd + ( + conj_multiplier, + -1 * conj_multiplier, + conj_multiplier, + -1 * conj_multiplier + ); + + /********* DIAGONAL ELEMENTS *********/ + // Solving for the diagonal elements using a scalar loop + for ( i = 0; i < m; i++ ) + { + xc = x + i*incx; + xcR = xc->real; + xcI = xc->imag; + xhermcR = xc->real; + xhermcI = xc->imag; + + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR + xcI * xhermcI; + + cc = c + (i)*rs_ct + (i)*cs_ct; + cc->real += interR; + cc->imag = 0; + } + + // Vectorized loop + for ( i = 0; ( i + 3 ) < m; i += 4 ) + { + // Broadcasting elements of x to ymm0-1 for computing xherm vector + // ymm0 = x0R x0I x1R x1I + ymm0 = _mm256_broadcast_pd( (__m128d const*)( x + i*incx ) ); + ymm1 = _mm256_broadcast_pd( (__m128d const*)( x + (i + 1)*incx ) ); + ymm2 = _mm256_broadcast_pd( (__m128d const*)( x + (i + 2)*incx ) ); + ymm3 = _mm256_broadcast_pd( (__m128d const*)( x + (i + 3)*incx ) ); + + // Scaling xherm vector with alpha + // alphaRv = aR aR aR aR + // ymm0 = x0R -x0I x1R -x1I + // ymm0 * alphaRv = aR.x0R -aR.x0I aR.x1R -aR.x1I + ymm0 = _mm256_mul_pd( ymm0, alphaRv ); + ymm1 = _mm256_mul_pd( ymm1, alphaRv ); + ymm2 = _mm256_mul_pd( ymm2, alphaRv ); + ymm3 = _mm256_mul_pd( ymm3, alphaRv ); + + // Shuffling xherm vector for multiplication with x vector + // ymm0_shuf = -x0I x0R -x1I x1R + ymm0_shuf = _mm256_permute_pd( ymm0, 5 ); + ymm1_shuf = _mm256_permute_pd( ymm1, 5 ); + ymm2_shuf = _mm256_permute_pd( ymm2, 5 ); + ymm3_shuf = _mm256_permute_pd( ymm3, 5 ); + + /********* TRIANGULAR BLOCK *********/ + // Solving the corner elements of the triangular block + // using scalar multiplication + xc = x + (i + 1)*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + (i)*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + + cc = c + (i + 1)*rs_ct + (i + 0)*cs_ct; + cc->real += interR; + cc->imag += interI; + + xc = x + (i + 3)*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + (i + 2)*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + + cc = c + (i + 3)*rs_ct + (i + 2)*cs_ct; + cc->real += interR; + cc->imag += interI; + + // Solving the 2x2 square tile inside the triangular block + // using intrinsics + // Loading elements from x to ymm4 + // ymm4 = x2R x2I x2R x2I + ymm4 = _mm256_loadu_pd( (double*)( x + (i + 2)*incx ) ); + + // Loading a tile from matrix + // ymm10 = c20R c20I c21R c21I + // ymm11 = c30R c30I c31R c31I + ymm10 = _mm256_loadu_pd + ( + (double*)( c + (i + 2)*rs_ct + (i)*cs_ct ) + ); + ymm11 = _mm256_loadu_pd + ( + (double*)( c + (i + 2)*rs_ct + (i + 1)*cs_ct ) + ); + + // Separating the 
real & imaginary parts of x into ymm4-7 + // ymm6 -> imag of ymm4 + // ymm4 -> real of ymm4 + ymm6 = _mm256_permute_pd( ymm4, 15 ); + ymm4 = _mm256_permute_pd( ymm4, 0 ); + + // Applying conjugate to elements of x vector + ymm6 = _mm256_mul_pd( ymm6, conj_mulv ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + ymm9 = _mm256_mul_pd( ymm4, ymm1 ); + ymm9 = _mm256_fmadd_pd( ymm6, ymm1_shuf, ymm9 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (i + 2)*rs_ct + (i)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (i + 2)*rs_ct + (i + 1)*cs_ct ), + ymm11 + ); + + /********* SQUARE BLOCK *********/ + // Solving a 4x4 square block of matrix using intrinsics + for ( j = (i + 4); (j + 3) < m; j += 4) + { + // Loading elements from x to ymm4-5 + ymm4 = _mm256_loadu_pd( (double*)( x + j*incx ) ); + ymm5 = _mm256_loadu_pd( (double*)( x + (j + 2)*incx ) ); + + // Separating the real & imaginary parts of x into ymm4-7 + // ymm6 -> imag of ymm4 + // ymm4 -> real of ymm4 + ymm6 = _mm256_permute_pd( ymm4, 15 ); + ymm4 = _mm256_permute_pd( ymm4, 0 ); + ymm7 = _mm256_permute_pd( ymm5, 15 ); + ymm5 = _mm256_permute_pd( ymm5, 0 ); + + // Applying conjugate to elements of x vector + ymm6 = _mm256_mul_pd( ymm6, conj_mulv ); + ymm7 = _mm256_mul_pd( ymm7, conj_mulv ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd + ( + (double*)( c + (j)*rs_ct + (i)*cs_ct ) + ); + ymm11 = _mm256_loadu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i)*cs_ct ) + ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm0 ); + ymm9 = _mm256_mul_pd( ymm5, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm0_shuf, ymm8 ); + ymm9 = _mm256_fmadd_pd( ymm7, ymm0_shuf, ymm9 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j)*rs_ct + (i)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i)*cs_ct ), + ymm11 + ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd + ( + (double*)( c + (j)*rs_ct + (i + 1)*cs_ct ) + ); + ymm11 = _mm256_loadu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i + 1)*cs_ct ) + ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm1 ); + ymm9 = _mm256_mul_pd( ymm5, ymm1 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm1_shuf, ymm8 ); + ymm9 = _mm256_fmadd_pd( ymm7, ymm1_shuf, ymm9 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j)*rs_ct + (i + 1)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i + 1)*cs_ct ), + ymm11 + ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd + ( + (double*)( c + (j)*rs_ct + (i + 2)*cs_ct ) + ); + ymm11 = _mm256_loadu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i + 2)*cs_ct ) + ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm2 ); + ymm9 = _mm256_mul_pd( ymm5, ymm2 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm2_shuf, ymm8 ); + ymm9 = _mm256_fmadd_pd( ymm7, ymm2_shuf, 
ymm9 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j)*rs_ct + (i + 2)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i + 2)*cs_ct ), + ymm11 + ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd + ( + (double*)( c + (j)*rs_ct + (i + 3)*cs_ct ) + ); + ymm11 = _mm256_loadu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i + 3)*cs_ct ) + ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm3 ); + ymm9 = _mm256_mul_pd( ymm5, ymm3 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm3_shuf, ymm8 ); + ymm9 = _mm256_fmadd_pd( ymm7, ymm3_shuf, ymm9 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j)*rs_ct + (i + 3)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i + 3)*cs_ct ), + ymm11 + ); + } + + // Solving a 2x2 square block of matrix using intrinsics + for ( ; (j + 1) < m; j += 2) + { + // Loading elements from x to ymm4 + ymm4 = _mm256_loadu_pd( (double*)( x + j*incx ) ); + + // Separating the real & imaginary parts of x into ymm4-7 + // ymm6 -> imag of ymm4 + // ymm4 -> real of ymm4 + ymm6 = _mm256_permute_pd( ymm4, 15 ); + ymm4 = _mm256_permute_pd( ymm4, 0 ); + + // Applying conjugate to elements of x vector + ymm6 = _mm256_mul_pd( ymm6, conj_mulv ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd( (double*)( c + (j)*rs_ct + (i)*cs_ct ) ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + // Storing back the results to the matrix + _mm256_storeu_pd( (double*)( c + (j)*rs_ct + (i)*cs_ct ), ymm10 ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd( (double*)( c + j*rs_ct + (i + 1)*cs_ct ) ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm1 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm1_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + // Storing back the results to the matrix + _mm256_storeu_pd( (double*)( c + j*rs_ct + (i + 1)*cs_ct ), ymm10 ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd( (double*)( c + j*rs_ct + (i + 2)*cs_ct ) ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm2 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm2_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + // Storing back the results to the matrix + _mm256_storeu_pd( (double*)( c + j*rs_ct + (i + 2)*cs_ct ), ymm10 ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd( (double*)( c + j*rs_ct + (i + 3)*cs_ct ) ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm3 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm3_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + // Storing back the results to the matrix + _mm256_storeu_pd( (double*)( c + j*rs_ct + (i + 3)*cs_ct ), ymm10 ); + } + + // Calculating for the remaining elements using scalar code + for ( ; j < m; j++ ) + { + xc = x + j*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = 
x + i*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + + // c + ((alpha * x) * xherm) + cc = c + (j)*rs_ct + (i)*cs_ct; + cc->real += interR; + cc->imag += interI; + + xc = x + j*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + (i + 1)*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + + // c + ((alpha * x) * xherm) + cc = c + (j)*rs_ct + (i + 1)*cs_ct; + cc->real += interR; + cc->imag += interI; + + xc = x + j*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + (i + 2)*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + + // c + ((alpha * x) * xherm) + cc = c + (j)*rs_ct + (i + 2)*cs_ct; + cc->real += interR; + cc->imag += interI; + + xc = x + j*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + (i + 3)*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + + // c + ((alpha * x) * xherm) + cc = c + (j)*rs_ct + (i + 3)*cs_ct; + cc->real += interR; + cc->imag += interI; + } + } + + for ( ; ( i + 1 ) < m; i += 2 ) + { + /********* TRIANGULAR BLOCK *********/ + // Solving the corner elements of the triangular block + // using scalar multiplication + xc = x + (i + 1)*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + i*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + + cc = c + (i + 1)*rs_ct + i*cs_ct; + cc->real += interR; + cc->imag += interI; + + // Solving the remaining elements in square block + // using scalar code + for ( j = (i + 2); j < m; j++ ) + { + xc = x + j*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + i*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + + // c + ((alpha * x) * xherm) + cc = c + (j)*rs_ct + (i)*cs_ct; + cc->real += interR; + cc->imag += interI; + + xc = x + j*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + (i + 1)*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + + // c + ((alpha * x) * xherm) + cc = c + (j)*rs_ct + (i + 1)*cs_ct; + cc->real += interR; + cc->imag += interI; + } + } +} \ No newline at end of file diff --git a/kernels/zen/3/CMakeLists.txt b/kernels/zen/3/CMakeLists.txt index 80f78b471b..d90e4e3902 100644 --- a/kernels/zen/3/CMakeLists.txt +++ b/kernels/zen/3/CMakeLists.txt @@ -1,12 +1,11 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2020-2022, Advanced Micro Devices, Inc. 
All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_small.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsm_small.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_ref_k1.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_sqp_kernels.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_sqp.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemm_ref_k1.c ) add_subdirectory(sup) diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c index 0cf5c8c5ce..22bb48f737 100644 --- a/kernels/zen/3/bli_gemm_small.c +++ b/kernels/zen/3/bli_gemm_small.c @@ -1951,12 +1951,12 @@ static err_t bli_sgemm_small tA_packed = D_A_pack; #ifdef BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 15), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 15), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 15), _MM_HINT_T0); #endif // clear scratch registers. ymm4 = _mm256_setzero_pd(); @@ -2111,12 +2111,12 @@ static err_t bli_sgemm_small tA = tA_packed + row_idx_packed; #ifdef BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 15), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 15), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 15), _MM_HINT_T0); #endif // clear scratch registers. ymm4 = _mm256_setzero_pd(); @@ -4513,12 +4513,12 @@ err_t bli_dgemm_small_At tA = tA_packed + row_idx_packed; #ifdef BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 15), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 15), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 15), _MM_HINT_T0); #endif // clear scratch registers. 
ymm4 = _mm256_setzero_pd(); @@ -7274,7 +7274,7 @@ err_t bli_zgemm_small } m_remainder = M - row_idx; - if ((m_remainder == 3)) + if (m_remainder == 3) { m_remainder -= 3; __m128d xmm0; @@ -8213,7 +8213,7 @@ err_t bli_zgemm_small _mm_storeu_pd((double *)(tC + 2), xmm0); } } - if ((m_remainder == 2)) + if (m_remainder == 2) { m_remainder -= 2; @@ -8952,7 +8952,7 @@ err_t bli_zgemm_small _mm256_storeu_pd((double *)tC, ymm8); } } - if ((m_remainder == 1)) + if (m_remainder == 1) { m_remainder -= 1; __m128d xmm0; @@ -10842,7 +10842,7 @@ err_t bli_zgemm_small_At } m_remainder = M - row_idx; - if ((m_remainder == 3)) + if (m_remainder == 3) { m_remainder -= 3; __m128d xmm0; @@ -11832,7 +11832,7 @@ err_t bli_zgemm_small_At _mm_storeu_pd((double *)(tC + 2), xmm0); } } - if ((m_remainder == 2)) + if (m_remainder == 2) { m_remainder -= 2; @@ -12615,7 +12615,7 @@ err_t bli_zgemm_small_At _mm256_storeu_pd((double *)tC, ymm8); } } - if ((m_remainder == 1)) + if (m_remainder == 1) { m_remainder -= 1; __m128d xmm0; diff --git a/kernels/zen/3/bli_gemm_sqp.c b/kernels/zen/3/bli_gemm_sqp.c deleted file mode 100644 index ceab622bf3..0000000000 --- a/kernels/zen/3/bli_gemm_sqp.c +++ /dev/null @@ -1,1203 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2021, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ -#include "blis.h" -#include "immintrin.h" -#include "bli_gemm_sqp_kernels.h" - -#define SQP_THREAD_ENABLE 0//currently disabled -#define BLI_SQP_MAX_THREADS 128 -#define BLIS_LOADFIRST 0 -#define MEM_ALLOC 1//malloc performs better than bli_malloc. 
- -#define SET_TRANS(X,Y)\ - Y = BLIS_NO_TRANSPOSE;\ - if(bli_obj_has_trans( a ))\ - {\ - Y = BLIS_TRANSPOSE;\ - if(bli_obj_has_conj(a))\ - {\ - Y = BLIS_CONJ_TRANSPOSE;\ - }\ - }\ - else if(bli_obj_has_conj(a))\ - {\ - Y = BLIS_CONJ_NO_TRANSPOSE;\ - } - -//Macro for 3m_sqp n loop -#define BLI_SQP_ZGEMM_N(MX)\ - int j=0;\ - for(; j<=(n-nx); j+= nx)\ - {\ - status = bli_sqp_zgemm_m8( m, nx, k, a, lda, b+(j*ldb), ldb, c+(j*ldc), ldc, alpha_real, beta_real, transa, MX, p_istart, kx, &mem_3m_sqp);\ - }\ - if(jreal; - double alpha_imag = alphap->imag; - double beta_real = betap->real; - double beta_imag = betap->imag; - if( (alpha_imag!=0)||(beta_imag!=0) ) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - //printf("zsqp "); - return bli_sqp_zgemm( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, transa, nt); - } - else if(dt == BLIS_DOUBLE) - { - double *alpha_cast, *beta_cast; - alpha_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, alpha); - beta_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, beta); - - if((*beta_cast)!=1.0) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - if(((*alpha_cast)!=1.0)&&((*alpha_cast)!=-1.0)) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - //printf("dsqp "); - // dgemm case only transpose or no-transpose is handled. - // conjugate_transpose and conjugate no transpose are not applicable. - return bli_sqp_dgemm( m, n, k, ap, lda, bp, ldb, cp, ldc, *alpha_cast, *beta_cast, isTransA, nt); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); - return BLIS_NOT_YET_IMPLEMENTED; -}; - -//sqp_dgemm k partition -BLIS_INLINE void bli_sqp_dgemm_kx( gint_t m, - gint_t n, - gint_t kx, - gint_t p, - double* a, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc, - bool isTransA, - double alpha, - gint_t mx, - gint_t i, - bool pack_on, - double *aligned) -{ - inc_t j = 0; - double* ci = c + i; - double* aPacked; - //packing - if(pack_on==true) - { - aPacked = aligned; - double *pa = a + i + (p*lda); - if(isTransA==true) - { - pa = a + (i*lda) + p; - } - bli_sqp_prepackA(pa, aPacked, kx, lda, isTransA, alpha, mx); - } - else - { - aPacked = a+i + (p*lda); - } - - //compute - if(mx==8) - { - //printf("\n mx8i:%3ld ", i); - if (j <= (n - 6)) - { - j = bli_sqp_dgemm_kernel_8mx6n(n, kx, j, aPacked, lda, b + p, ldb, ci, ldc); - } - if (j <= (n - 5)) - { - j = bli_sqp_dgemm_kernel_8mx5n(n, kx, j, aPacked, lda, b + (j * ldb) + p, ldb, ci + (j * ldc), ldc); - } - if (j <= (n - 4)) - { - j = bli_sqp_dgemm_kernel_8mx4n(n, kx, j, aPacked, lda, b + (j * ldb) + p, ldb, ci + (j * ldc), ldc); - } - if (j <= (n - 3)) - { - j = bli_sqp_dgemm_kernel_8mx3n(n, kx, j, aPacked, lda, b + (j * ldb) + p, ldb, ci + (j * ldc), ldc); - } - if (j <= (n - 2)) - { - j = bli_sqp_dgemm_kernel_8mx2n(n, kx, j, aPacked, lda, b + (j * ldb) + p, ldb, ci + (j * ldc), ldc); - } - if (j <= (n - 1)) - { - j = bli_sqp_dgemm_kernel_8mx1n(n, kx, j, aPacked, lda, b + (j * ldb) + p, ldb, ci + (j * ldc), ldc); - } - } - /* mx==4 to be implemented */ - else - { - // this residue kernel needs to be improved. 
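bli_sqp_dgemm_kx above walks the n dimension through a cascade of kernels of decreasing width (8mx6n, 8mx5n, ..., 8mx1n); each call consumes as many full column blocks as it can and returns the next unprocessed column index j, so the narrower kernels only ever see the tail. A toy sketch of that dispatch contract, with hypothetical kernel names standing in for the bli_sqp_dgemm_kernel_8mx*n routines:

/* Sketch of the return-the-next-column dispatch pattern: each stand-in
   kernel advances j by whole blocks of its width and hands back the first
   column it did not process. */
typedef long idx_t;

static idx_t kernel_w6(idx_t n, idx_t j) { while (j <= n - 6) j += 6; return j; }
static idx_t kernel_w1(idx_t n, idx_t j) { while (j <= n - 1) j += 1; return j; }

static void dispatch_n(idx_t n)
{
    idx_t j = 0;
    if (j <= n - 6) j = kernel_w6(n, j);   /* bulk: six columns per call    */
    if (j <= n - 1) j = kernel_w1(n, j);   /* residue: one column at a time */
    /* j == n here: every column was handled by exactly one kernel width.  */
}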
- j = bli_sqp_dgemm_kernel_mxn(n, kx, j, aPacked, lda, b + p, ldb, ci, ldc, mx); - } -} - -//sqp dgemm m loop -void bli_sqp_dgemm_m( gint_t i_start, - gint_t i_end, - gint_t m, - gint_t n, - gint_t k, - gint_t kx, - double* a, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc, - bool isTransA, - double alpha, - gint_t mx, - bool pack_on, - double *aligned ) -{ -#if SQP_THREAD_ENABLE - if(pack_on==true) - { - //NEEDED IN THREADING CASE: - aligned = (double*)bli_malloc_user(sizeof(double) * kx * mx); - if(aligned==NULL) - { - return BLIS_MALLOC_RETURNED_NULL;// return to be removed - } - } -#endif//SQP_THREAD_ENABLE - - for (gint_t i = i_start; i <= (i_end-mx); i += mx) //this loop can be threaded. no of workitems = m/8 - { - int p = 0; - for(; p <= (k-kx); p += kx) - { - bli_sqp_dgemm_kx(m, n, kx, p, a, lda, b, ldb, c, ldc, isTransA, alpha, mx, i, pack_on, aligned); - }// k loop end - - if(pi_start, - arg->i_end, - arg->m, - arg->n, - arg->k, - arg->kx, - arg->a, - arg->lda, - arg->b, - arg->ldb, - arg->c, - arg->ldc, - arg->isTransA, - arg->alpha, - arg->mx, - arg->pack_on, - arg->aligned); -} - -// sqp_dgemm m loop -BLIS_INLINE err_t bli_sqp_dgemm_m8( gint_t m, - gint_t n, - gint_t k, - double* a, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc, - bool isTransA, - double alpha, - gint_t mx, - gint_t* p_istart, - gint_t kx, - double *aligned) -{ - gint_t i; - if(kx > k) - { - kx = k; - } - - bool pack_on = false; - if((m!=mx)||(m!=lda)||isTransA) - { - pack_on = true; - } - -#if 0//SQP_THREAD_ENABLE//ENABLE Threading - gint_t status = 0; - gint_t workitems = (m-(*p_istart))/mx; - gint_t inputThreadCount = bli_thread_get_num_threads(); - inputThreadCount = bli_min(inputThreadCount, BLI_SQP_MAX_THREADS); - inputThreadCount = bli_min(inputThreadCount,workitems);// limit input thread count when workitems are lesser. - inputThreadCount = bli_max(inputThreadCount,1); - gint_t num_threads; - num_threads = bli_max(inputThreadCount,1); - gint_t mx_per_thread = workitems/num_threads;//no of workitems per thread - //printf("\nistart %d workitems %d inputThreadCount %d num_threads %d mx_per_thread %d mx %d " , - *p_istart, workitems,inputThreadCount,num_threads,mx_per_thread, mx); - - pthread_t ptid[BLI_SQP_MAX_THREADS]; - bli_sqp_thread_info thread_info[BLI_SQP_MAX_THREADS]; - - //create threads - for (gint_t t = 0; t < num_threads; t++) - { - //ptid[t].tid = t; - gint_t i_end = ((mx_per_thread*(t+1))*mx)+(*p_istart); - if(i_end>m) - { - i_end = m; - } - - if(t==(num_threads-1)) - { - if((i_end+mx)==m) - { - i_end = m; - } - - if(mx==1) - { - i_end = m; - } - } - - thread_info[t].i_start = ((mx_per_thread*t)*mx)+(*p_istart); - thread_info[t].i_end = i_end; - //printf("\n threadid %d istart %d iend %d m %d mx %d", t, thread_info[t].i_start, i_end, m, mx); - thread_info[t].m = m; - thread_info[t].n = n; - thread_info[t].k = k; - thread_info[t].kx = kx; - thread_info[t].a = a; - thread_info[t].lda = lda; - thread_info[t].b = b; - thread_info[t].ldb = ldb; - thread_info[t].c = c; - thread_info[t].ldc = ldc; - thread_info[t].isTransA = isTransA; - thread_info[t].alpha = alpha; - thread_info[t].mx = mx; - thread_info[t].pack_on = pack_on; - thread_info[t].aligned = aligned; -#if 1 - if ((status = pthread_create(&ptid[t], NULL, bli_sqp_thread, (void*)&thread_info[t]))) - { - printf("error sqp pthread_create\n"); - return BLIS_FAILURE; - } -#else - //simulate thread for debugging.. 
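The pthread path above (compiled out by SQP_THREAD_ENABLE) slices the m dimension into per-thread [i_start, i_end) ranges that are whole multiples of the block height mx, with the last thread absorbing the clamped tail. A small sketch of that partitioning arithmetic; sqp_thread_range is a hypothetical helper and assumes num_threads >= 1, as the original clamps guarantee:

/* Per-thread row range, mirroring the disabled pthread setup above. */
typedef struct { long i_start, i_end; } row_range;

static row_range sqp_thread_range(long t, long num_threads,
                                  long m, long mx, long istart)
{
    long workitems     = (m - istart) / mx;       /* blocks of height mx */
    long mx_per_thread = workitems / num_threads; /* blocks per thread   */
    row_range r;
    r.i_start = (mx_per_thread *  t     ) * mx + istart;
    r.i_end   = (mx_per_thread * (t + 1)) * mx + istart;
    if (r.i_end > m) r.i_end = m;                 /* clamp at the matrix edge   */
    if (t == num_threads - 1 && r.i_end + mx == m)
        r.i_end = m;                              /* last thread takes the tail */
    return r;
}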
- bli_sqp_thread((void*)&thread_info[t]); -#endif - } - - //wait for completion - for (gint_t t = 0; t < num_threads; t++) - { - pthread_join(ptid[t], NULL); - } - - if(num_threads>0) - { - *p_istart = thread_info[(num_threads-1)].i_end; - } -#else//SQP_THREAD_ENABLE - - if(pack_on==true) - { - //aligned = (double*)bli_malloc_user(sizeof(double) * kx * mx); // allocation moved to top. - if(aligned==NULL) - { - return BLIS_MALLOC_RETURNED_NULL; - } - } - - for (i = (*p_istart); i <= (m-mx); i += mx) //this loop can be threaded. no of workitems = m/8 - { - int p = 0; - for(; p <= (k-kx); p += kx) - { - bli_sqp_dgemm_kx(m, n, kx, p, a, lda, b, ldb, c, ldc, - isTransA, alpha, mx, i, pack_on, aligned); - }// k loop end - - if(pdata_size * mem_req->size; - if (memSize == 0) - { - return -1; - } - memSize += 128;// extra 128 bytes added for alignment. Could be minimized to 64. -#if MEM_ALLOC -#ifdef BLIS_ENABLE_MEM_TRACING - printf( "malloc(): size %ld\n",( long )memSize); - fflush( stdout ); -#endif - mem_req->unalignedBuf = (double*)malloc(memSize); - if (mem_req->unalignedBuf == NULL) - { - return -1; - } - - int64_t address = (int64_t)mem_req->unalignedBuf; - address += (-address) & 63; //64 bytes alignment done. - mem_req->alignedBuf = (double*)address; -#else - mem_req->alignedBuf = bli_malloc_user( memSize ); - if (mem_req->alignedBuf == NULL) - { - return -1; - } -#endif - return 0; -} - -gint_t bli_allocateWorkspace(gint_t n, gint_t k, mem_block *mxr, mem_block *mxi, mem_block *msx) -{ - //allocate workspace - mxr->data_size = mxi->data_size = msx->data_size = sizeof(double); - mxr->size = mxi->size = n * k; - msx->size = n * k; - mxr->alignedBuf = mxi->alignedBuf = msx->alignedBuf = NULL; - mxr->unalignedBuf = mxi->unalignedBuf = msx->unalignedBuf = NULL; - - if (!((bli_getaligned(mxr) == 0) && (bli_getaligned(mxi) == 0) && (bli_getaligned(msx) == 0))) - { -#if MEM_ALLOC - if(mxr->unalignedBuf) - { - free(mxr->unalignedBuf); - } - if(mxi->unalignedBuf) - { - free(mxi->unalignedBuf); - } - if(msx->unalignedBuf) - { - free(msx->unalignedBuf); - } -#else - bli_free_user(mxr->alignedBuf); - bli_free_user(mxi->alignedBuf); - bli_free_user(msx->alignedBuf); -#endif - return -1; - } - return 0; -} - -//3m_sqp k loop -BLIS_INLINE void bli_sqp_zgemm_kx( gint_t m, - gint_t n, - gint_t kx, - gint_t p, - double* a, - guint_t lda, - guint_t ldb, - double* c, - guint_t ldc, - trans_t transa, - double alpha, - double beta, - gint_t mx, - gint_t i, - double* ar, - double* ai, - double* as, - double* br, - double* bi, - double* bs, - double* cr, - double* ci, - double* w, - double *a_aligned) -{ - gint_t j; - - ////////////// operation 1 ///////////////// - /* Split a (ar, ai) and - compute as = ar + ai */ - double* par = ar; - double* pai = ai; - double* pas = as; - - /* a matrix real and imag packing and compute. */ - bli_3m_sqp_packA_real_imag_sum(a, i, kx+p, lda, par, pai, pas, transa, mx, p); - - double* pcr = cr; - double* pci = ci; - - //Split Cr and Ci and beta multiplication done. 
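bli_getaligned above implements 64-byte alignment by over-allocating (the request is padded by 128 bytes) and rounding the raw address up with (-address) & 63, keeping the unaligned pointer for the eventual free(). A self-contained sketch of the same idiom; alloc_aligned64 is an illustrative name, not an API of the library:

#include <stdint.h>
#include <stdlib.h>

/* Over-allocate, then advance to the next 64-byte boundary. The caller
   frees the *unaligned* pointer, which is why both are kept. */
static double *alloc_aligned64(size_t bytes, void **unaligned)
{
    void *raw = malloc(bytes + 128);      /* head-room for the round-up */
    if (raw == NULL) return NULL;

    uintptr_t addr = (uintptr_t)raw;
    addr += (-addr) & 63;                 /* 0..63 bytes of padding     */

    *unaligned = raw;                     /* keep for free()            */
    return (double *)addr;
}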
- double* pc = c + i; - if(p==0) - { - bli_3m_sqp_packC_real_imag(pc, n, mx, ldc, pcr, pci, beta, mx); - } - //Ci := rgemm( SA, SB, Ci ) - gint_t istart = 0; - gint_t* p_is = &istart; - *p_is = 0; - bli_sqp_dgemm_m8(mx, n, kx, as, mx, bs, ldb, ci, mx, false, 1.0, mx, p_is, kx, a_aligned); - - ////////////// operation 2 ///////////////// - //Wr: = dgemm_sqp(Ar, Br, 0) // Wr output 8xn - double* wr = w; - for (j = 0; j < n; j++) { - for (gint_t ii = 0; ii < mx; ii += 1) { - *wr = 0; - wr++; - } - } - wr = w; - - *p_is = 0; - bli_sqp_dgemm_m8(mx, n, kx, ar, mx, br, ldb, wr, mx, false, 1.0, mx, p_is, kx, a_aligned); - //Cr : = addm(Wr, Cr) - bli_add_m(mx, n, wr, cr); - //Ci : = subm(Wr, Ci) - bli_sub_m(mx, n, wr, ci); - - - ////////////// operation 3 ///////////////// - //Wi : = dgemm_sqp(Ai, Bi, 0) // Wi output 8xn - double* wi = w; - for (j = 0; j < n; j++) { - for (gint_t ii = 0; ii < mx; ii += 1) { - *wi = 0; - wi++; - } - } - wi = w; - - *p_is = 0; - bli_sqp_dgemm_m8(mx, n, kx, ai, mx, bi, ldb, wi, mx, false, 1.0, mx, p_is, kx, a_aligned); - //Cr : = subm(Wi, Cr) - bli_sub_m(mx, n, wi, cr); - //Ci : = subm(Wi, Ci) - bli_sub_m(mx, n, wi, ci); - - pcr = cr; - pci = ci; - - for (j = 0; j < n; j++) - { - for (gint_t ii = 0; ii < (mx*2); ii += 2) - { - c[(j * ldc) + i + ii] = *pcr; - c[(j * ldc) + i + ii + 1] = *pci; - pcr++; pci++; - } - } -} - -/**************************************************************/ -/* workspace memory allocation for 3m_sqp algorithm for zgemm */ -/**************************************************************/ -err_t allocate_3m_Sqp_workspace(workspace_3m_sqp *mem_3m_sqp, - gint_t mx, - gint_t nx, - gint_t k, - gint_t kx ) -{ - //3m_sqp workspace Memory allocation - /* B matrix */ - // B matrix packed with n x k size. without kx smaller sizes for now. 
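The three operations above are the 3m scheme: with As = Ar + Ai and Bs = Br + Bi packed beforehand, three real GEMMs produce Cr = Ar*Br - Ai*Bi and Ci = As*Bs - Ar*Br - Ai*Bi, i.e. the complex product with three real multiplications instead of four. A scalar illustration of the identity (dcomplex_t and cmul_3m are local stand-ins, not the kernel):

/* Scalar form of the update the three real GEMMs perform per element. */
typedef struct { double r, i; } dcomplex_t;

static dcomplex_t cmul_3m(dcomplex_t a, dcomplex_t b)
{
    double s  = (a.r + a.i) * (b.r + b.i);   /* As * Bs  (operation 1) */
    double rr = a.r * b.r;                   /* Ar * Br  (operation 2) */
    double ii = a.i * b.i;                   /* Ai * Bi  (operation 3) */

    dcomplex_t c = { rr - ii, s - rr - ii }; /* Cr, Ci */
    return c;
}

The saved multiplications are paid for with extra additions and with the separate real/imaginary/sum pack buffers allocated below.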
- mem_block mbr, mbi, mbs; - if(bli_allocateWorkspace(nx, k, &mbr, &mbi, &mbs)!=0) - { - return BLIS_FAILURE; - } - mem_3m_sqp->br = (double*)mbr.alignedBuf; - mem_3m_sqp->bi = (double*)mbi.alignedBuf; - mem_3m_sqp->bs = (double*)mbs.alignedBuf; - mem_3m_sqp->br_unaligned = (double*)mbr.unalignedBuf; - mem_3m_sqp->bi_unaligned = (double*)mbi.unalignedBuf; - mem_3m_sqp->bs_unaligned = (double*)mbs.unalignedBuf; - - /* Workspace memory allocation currently done dynamically - This needs to be taken from already allocated memory pool in application for better performance */ - /* A matrix */ - mem_block mar, mai, mas; - if(bli_allocateWorkspace(mx, kx, &mar, &mai, &mas) !=0) - { - return BLIS_FAILURE; - } - mem_3m_sqp->ar = (double*)mar.alignedBuf; - mem_3m_sqp->ai = (double*)mai.alignedBuf; - mem_3m_sqp->as = (double*)mas.alignedBuf; - mem_3m_sqp->ar_unaligned = (double*)mar.unalignedBuf; - mem_3m_sqp->ai_unaligned = (double*)mai.unalignedBuf; - mem_3m_sqp->as_unaligned = (double*)mas.unalignedBuf; - - /* w matrix */ - mem_block mw; - mw.data_size = sizeof(double); - mw.size = mx * nx; - if (bli_getaligned(&mw) != 0) - { - return BLIS_FAILURE; - } - mem_3m_sqp->w = (double*)mw.alignedBuf; - mem_3m_sqp->w_unaligned = (double*)mw.unalignedBuf; - /* cr matrix */ - mem_block mcr; - mcr.data_size = sizeof(double); - mcr.size = mx * nx; - if (bli_getaligned(&mcr) != 0) - { - return BLIS_FAILURE; - } - mem_3m_sqp->cr = (double*)mcr.alignedBuf; - mem_3m_sqp->cr_unaligned = (double*)mcr.unalignedBuf; - - - /* ci matrix */ - mem_block mci; - mci.data_size = sizeof(double); - mci.size = mx * nx; - if (bli_getaligned(&mci) != 0) - { - return BLIS_FAILURE; - } - mem_3m_sqp->ci = (double*)mci.alignedBuf; - mem_3m_sqp->ci_unaligned = (double*)mci.unalignedBuf; - - // A packing buffer - mem_3m_sqp->aPacked = (double*)bli_malloc_user(sizeof(double) * kx * mx); - if (mem_3m_sqp->aPacked == NULL) - { - return BLIS_FAILURE; - } - - return BLIS_SUCCESS; -} - -void free_3m_Sqp_workspace(workspace_3m_sqp *mem_3m_sqp) -{ - // A packing buffer free - bli_free_user(mem_3m_sqp->aPacked); - -#if MEM_ALLOC - if(mem_3m_sqp->ar_unaligned) - { - free(mem_3m_sqp->ar_unaligned); - } - if(mem_3m_sqp->ai_unaligned) - { - free(mem_3m_sqp->ai_unaligned); - } - if(mem_3m_sqp->as_unaligned) - { - free(mem_3m_sqp->as_unaligned); - } - - if(mem_3m_sqp->br_unaligned) - { - free(mem_3m_sqp->br_unaligned); - } - if(mem_3m_sqp->bi_unaligned) - { - free(mem_3m_sqp->bi_unaligned); - } - if(mem_3m_sqp->bs_unaligned) - { - free(mem_3m_sqp->bs_unaligned); - } - - if(mem_3m_sqp->w_unaligned) - { - free(mem_3m_sqp->w_unaligned); - } - if(mem_3m_sqp->cr_unaligned) - { - free(mem_3m_sqp->cr_unaligned); - } - if(mem_3m_sqp->ci_unaligned) - { - free(mem_3m_sqp->ci_unaligned); - } - -#else//MEM_ALLOC - /* free workspace buffers */ - bli_free_user(mem_3m_sqp->br); - bli_free_user(mem_3m_sqp->bi); - bli_free_user(mem_3m_sqp->bs); - bli_free_user(mem_3m_sqp->ar); - bli_free_user(mem_3m_sqp->ai); - bli_free_user(mem_3m_sqp->as); - bli_free_user(mem_3m_sqp->w); - bli_free_user(mem_3m_sqp->cr); - bli_free_user(mem_3m_sqp->ci); -#endif//MEM_ALLOC -} - -//3m_sqp m loop -BLIS_INLINE err_t bli_sqp_zgemm_m8( gint_t m, - gint_t n, - gint_t k, - double* a, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc, - double alpha, - double beta, - trans_t transa, - gint_t mx, - gint_t* p_istart, - gint_t kx, - workspace_3m_sqp *mem_3m_sqp) -{ - inc_t m2 = m<<1; - inc_t mxmul2 = mx<<1; - - if((*p_istart) > (m2-mxmul2)) - { - return BLIS_SUCCESS; - } - 
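For reference, the scratch buffers that allocate_3m_Sqp_workspace above sets up, with their sizes in doubles as passed to the allocation calls:

/* 3m_sqp workspace summary (sizes in doubles):
 *   br, bi, bs : nx * k   -- real part, imaginary part, real+imag sum of B
 *   ar, ai, as : mx * kx  -- the same split for one packed A panel
 *   w          : mx * nx  -- output tile of each intermediate real GEMM
 *   cr, ci     : mx * nx  -- split real/imaginary tile of C
 *   aPacked    : mx * kx  -- pack buffer reused by the real dgemm kernels
 * Every aligned pointer is paired with an *_unaligned twin so the raw
 * malloc() block can be released in free_3m_Sqp_workspace(). */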
inc_t i; - gint_t max_m = (m2-mxmul2); - - //get workspace - double* ar, * ai, * as; - ar = mem_3m_sqp->ar; - ai = mem_3m_sqp->ai; - as = mem_3m_sqp->as; - - double* br, * bi, * bs; - br = mem_3m_sqp->br; - bi = mem_3m_sqp->bi; - bs = mem_3m_sqp->bs; - - double* cr, * ci; - cr = mem_3m_sqp->cr; - ci = mem_3m_sqp->ci; - - double *w; - w = mem_3m_sqp->w; - - double* a_aligned; - a_aligned = mem_3m_sqp->aPacked; - - /* Split b (br, bi) and - compute bs = br + bi */ - double* pbr = br; - double* pbi = bi; - double* pbs = bs; - /* b matrix real and imag packing and compute. */ - bli_3m_sqp_packB_real_imag_sum(b, n, k, ldb, pbr, pbi, pbs, alpha, mx); - - for (i = (*p_istart); i <= max_m; i += mxmul2) //this loop can be threaded. - { -#if KLP//kloop - int p = 0; - for(; p <= (k-kx); p += kx) - { - bli_sqp_zgemm_kx(m, n, kx, p, a, lda, k, c, ldc, - transa, alpha, beta, mx, i, ar, ai, as, - br + p, bi + p, bs + p, cr, ci, w, a_aligned); - }// k loop end - - if(p>3)<<3); - - workspace_3m_sqp mem_3m_sqp; - - /* multiply lda, ldb and ldc by 2 to account for - real & imaginary components per dcomplex. */ - lda = lda * 2; - ldb = ldb * 2; - ldc = ldc * 2; - - /* user can set BLIS_MULTI_INSTANCE macro for - better performance while runing multi-instance use-case. - */ - dim_t multi_instance = bli_env_get_var( "BLIS_MULTI_INSTANCE", -1 ); - gint_t nx = n; - if(multi_instance>0) - { - //limited nx size helps in reducing memory footprint in multi-instance case. - nx = 84; - // 84 is derived based on tuning results - } - - if(nx>n) - { - nx = n; - } - - gint_t kx = k;// kx is configurable at run-time. -#if KLP - if (kx > k) - { - kx = k; - } - // for tn case there is a bug in handling k parts. To be fixed. - if(transa!=BLIS_NO_TRANSPOSE) - { - kx = k; - } -#else - kx = k; -#endif - //3m_sqp workspace Memory allocation - if(allocate_3m_Sqp_workspace(&mem_3m_sqp, mx, nx, k, kx)!=BLIS_SUCCESS) - { - return BLIS_FAILURE; - } - - BLI_SQP_ZGEMM_N(mx) - *p_istart = (m-m8rem)*2; - - if(m8rem!=0) - { - //complete residue m blocks - BLI_SQP_ZGEMM_N(m8rem) - } - - free_3m_Sqp_workspace(&mem_3m_sqp); - return status; -} - -/****************************************************************************/ -/*********************** dgemm_sqp implementation****************************/ -/****************************************************************************/ -/* dgemm_sqp implementation packs A matrix based on lda and m size. - dgemm_sqp focuses mainly on square matrixes but also supports non-square matrix. - Current support is limiteed to m multiple of 8 and column storage. - C = AxB and C = AtxB is handled in the design. - AtxB case is done by transposing A matrix while packing A. - In majority of use-case, alpha are +/-1, so instead of explicitly multiplying - alpha its done during packing itself by changing sign. 
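The closing comment above notes that alpha = +/-1 is folded into the A pack by flipping the sign while copying, so the compute kernels never multiply by alpha. A minimal scalar sketch of that idea; pack_a_panel is illustrative, the vectorized versions are the bli_prepackA_* routines removed further down:

/* Copy a k-column panel of column-major A (leading dimension lda, panel
   height mx) into a contiguous pack buffer, negating on the fly when
   alpha == -1.0. */
static void pack_a_panel(const double *a, double *a_packed,
                         long k, long lda, long mx, double alpha)
{
    for (long p = 0; p < k; p++)
        for (long i = 0; i < mx; i++)
            a_packed[p * mx + i] = (alpha == -1.0) ? -a[p * lda + i]
                                                   :  a[p * lda + i];
}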
-*/ -BLIS_INLINE err_t bli_sqp_dgemm(gint_t m, - gint_t n, - gint_t k, - double* a, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc, - double alpha, - double beta, - bool isTransA, - dim_t nt) -{ - gint_t istart = 0; - gint_t* p_istart = &istart; - *p_istart = 0; - err_t status = BLIS_SUCCESS; - dim_t m8rem = m - ((m>>3)<<3); - - /* dgemm implementation with 8mx5n major kernel and column preferred storage */ - gint_t mx = 8; - gint_t kx = k; - double* a_aligned = NULL; - - if(nt<=1)//single pack buffer allocated for single thread case - { - a_aligned = (double*)bli_malloc_user(sizeof(double) * kx * mx); - } - - gint_t nx = n;//MAX; - if(nx>n) - { - nx = n; - } - - //mx==8 case for dgemm. - BLI_SQP_DGEMM_N(mx) - *p_istart = (m-m8rem); - - if(nt>1) - { - //2nd level thread for mx=8 - gint_t rem_m = m - (*p_istart); - if((rem_m>=mx)&&(status==BLIS_SUCCESS)) - { - status = bli_sqp_dgemm_m8( m, n, k, a, lda, b, ldb, c, ldc, - isTransA, alpha, mx, p_istart, kx, a_aligned); - } - } - - if(status==BLIS_SUCCESS) - { - if(m8rem!=0) - { - //complete residue m blocks - BLI_SQP_DGEMM_N(m8rem) - } - } - - if(nt<=1)//single pack buffer allocated for single thread case - { - bli_free_user(a_aligned); - } - return status; -} \ No newline at end of file diff --git a/kernels/zen/3/bli_gemm_sqp_kernels.c b/kernels/zen/3/bli_gemm_sqp_kernels.c deleted file mode 100644 index 0f20c0a956..0000000000 --- a/kernels/zen/3/bli_gemm_sqp_kernels.c +++ /dev/null @@ -1,1750 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2021, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -*/ -#include "blis.h" -#include "immintrin.h" -#include "bli_gemm_sqp_kernels.h" - -#define BLIS_LOADFIRST 0 -#define BLIS_ENABLE_PREFETCH 1 - -#define BLIS_MX8 8 -#define BLIS_MX4 4 -#define BLIS_MX1 1 - -/****************************************************************************/ -/*************** dgemm kernels (8mxn) column preffered *********************/ -/****************************************************************************/ - -/* Main dgemm kernel 8mx6n with single load and store of C matrix block - alpha = +/-1 and beta = +/-1,0 handled while packing.*/ -inc_t bli_sqp_dgemm_kernel_8mx6n(gint_t n, - gint_t k, - gint_t j, - double* aPacked, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc) -{ - gint_t p; - - __m256d av0, av1; - __m256d bv0, bv1; - __m256d cv0, cv1, cv2, cv3, cv4, cv5; - __m256d cx0, cx1, cx2, cx3, cx4, cx5; - double* pb, * pc; - - pb = b; - pc = c; - inc_t ldc6 = ldc * 6; inc_t ldb6 = ldb * 6; - - for (j = 0; j <= (n - 6); j += 6) { - double* pcldc = pc + ldc; - double* pcldc2 = pcldc + ldc; - double* pcldc3 = pcldc2 + ldc; - double* pcldc4 = pcldc3 + ldc; - double* pcldc5 = pcldc4 + ldc; - - double* pbldb = pb + ldb; - double* pbldb2 = pbldb + ldb; - double* pbldb3 = pbldb2 + ldb; - double* pbldb4 = pbldb3 + ldb; - double* pbldb5 = pbldb4 + ldb; - -#if BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(pc), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc3), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc4), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc5), _MM_HINT_T0); - - _mm_prefetch((char*)(aPacked), _MM_HINT_T0); - - _mm_prefetch((char*)(pb), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb2), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb5), _MM_HINT_T0); -#endif - /* C matrix column major load */ -#if BLIS_LOADFIRST - cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); - cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); - cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); - cv3 = _mm256_loadu_pd(pcldc3); cx3 = _mm256_loadu_pd(pcldc3 + 4); - cv4 = _mm256_loadu_pd(pcldc4); cx4 = _mm256_loadu_pd(pcldc4 + 4); - cv5 = _mm256_loadu_pd(pcldc5); cx5 = _mm256_loadu_pd(pcldc5 + 4); -#else - cv0 = _mm256_setzero_pd(); cx0 = _mm256_setzero_pd(); - cv1 = _mm256_setzero_pd(); cx1 = _mm256_setzero_pd(); - cv2 = _mm256_setzero_pd(); cx2 = _mm256_setzero_pd(); - cv3 = _mm256_setzero_pd(); cx3 = _mm256_setzero_pd(); - cv4 = _mm256_setzero_pd(); cx4 = _mm256_setzero_pd(); - cv5 = _mm256_setzero_pd(); cx5 = _mm256_setzero_pd(); -#endif - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - av0 = _mm256_loadu_pd(x); x += 4; av1 = _mm256_loadu_pd(x); x += 4; - bv0 = _mm256_broadcast_sd (pb0); pb0++; - bv1 = _mm256_broadcast_sd(pbldb); pbldb++; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - cx0 = _mm256_fmadd_pd(av1, bv0, cx0); - cv1 = _mm256_fmadd_pd(av0, bv1, cv1); - cx1 = _mm256_fmadd_pd(av1, bv1, cx1); - - bv0 = _mm256_broadcast_sd(pbldb2);pbldb2++; - bv1 = _mm256_broadcast_sd(pbldb3);pbldb3++; - cv2 = _mm256_fmadd_pd(av0, bv0, cv2); - cx2 = _mm256_fmadd_pd(av1, bv0, cx2); - cv3 = _mm256_fmadd_pd(av0, bv1, cv3); - cx3 = _mm256_fmadd_pd(av1, bv1, cx3); - - bv0 = _mm256_broadcast_sd(pbldb4);pbldb4++; - bv1 = _mm256_broadcast_sd(pbldb5);pbldb5++; - cv4 = _mm256_fmadd_pd(av0, bv0, cv4); - cx4 = 
_mm256_fmadd_pd(av1, bv0, cx4); - cv5 = _mm256_fmadd_pd(av0, bv1, cv5); - cx5 = _mm256_fmadd_pd(av1, bv1, cx5); - } -#if BLIS_LOADFIRST -#else - bv0 = _mm256_loadu_pd(pc); bv1 = _mm256_loadu_pd(pc + 4); - cv0 = _mm256_add_pd(cv0, bv0); cx0 = _mm256_add_pd(cx0, bv1); - - av0 = _mm256_loadu_pd(pcldc); av1 = _mm256_loadu_pd(pcldc + 4); - cv1 = _mm256_add_pd(cv1, av0); cx1 = _mm256_add_pd(cx1, av1); - - bv0 = _mm256_loadu_pd(pcldc2); bv1 = _mm256_loadu_pd(pcldc2 + 4); - cv2 = _mm256_add_pd(cv2, bv0); cx2 = _mm256_add_pd(cx2, bv1); - - av0 = _mm256_loadu_pd(pcldc3); av1 = _mm256_loadu_pd(pcldc3 + 4); - cv3 = _mm256_add_pd(cv3, av0); cx3 = _mm256_add_pd(cx3, av1); - - bv0 = _mm256_loadu_pd(pcldc4); bv1 = _mm256_loadu_pd(pcldc4 + 4); - cv4 = _mm256_add_pd(cv4, bv0); cx4 = _mm256_add_pd(cx4, bv1); - - av0 = _mm256_loadu_pd(pcldc5); av1 = _mm256_loadu_pd(pcldc5 + 4); - cv5 = _mm256_add_pd(cv5, av0); cx5 = _mm256_add_pd(cx5, av1); -#endif - /* C matrix column major store */ - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pc + 4, cx0); - - _mm256_storeu_pd(pcldc, cv1); - _mm256_storeu_pd(pcldc + 4, cx1); - - _mm256_storeu_pd(pcldc2, cv2); - _mm256_storeu_pd(pcldc2 + 4, cx2); - - _mm256_storeu_pd(pcldc3, cv3); - _mm256_storeu_pd(pcldc3 + 4, cx3); - - _mm256_storeu_pd(pcldc4, cv4); - _mm256_storeu_pd(pcldc4 + 4, cx4); - - _mm256_storeu_pd(pcldc5, cv5); - _mm256_storeu_pd(pcldc5 + 4, cx5); - - pc += ldc6;pb += ldb6; - } - //printf(" 8x6:j:%d ", j); - return j; -} - -/* alternative Main dgemm kernel 8mx5n with single load and store of C matrix block - alpha = +/-1 and beta = +/-1,0 handled while packing.*/ -inc_t bli_sqp_dgemm_kernel_8mx5n( gint_t n, - gint_t k, - gint_t j, - double* aPacked, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc) -{ - gint_t p; - __m256d av0; - __m256d bv0, bv1, bv2, bv3; - __m256d cv0, cv1, cv2, cv3; - __m256d cx0, cx1, cx2, cx3; - __m256d bv4, cv4, cx4; - double* pb, * pc; - - pb = b; - pc = c; - inc_t ldc5 = ldc * 5; inc_t ldb5 = ldb * 5; - - for (; j <= (n - 5); j += 5) { - - double* pcldc = pc + ldc; - double* pcldc2 = pcldc + ldc; - double* pcldc3 = pcldc2 + ldc; - double* pcldc4 = pcldc3 + ldc; - - double* pbldb = pb + ldb; - double* pbldb2 = pbldb + ldb; - double* pbldb3 = pbldb2 + ldb; - double* pbldb4 = pbldb3 + ldb; - -#if BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(pc), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc3), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc4), _MM_HINT_T0); - - _mm_prefetch((char*)(aPacked), _MM_HINT_T0); - - _mm_prefetch((char*)(pb), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb2), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); -#endif - /* C matrix column major load */ -#if BLIS_LOADFIRST - cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); - cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); - cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); - cv3 = _mm256_loadu_pd(pcldc3); cx3 = _mm256_loadu_pd(pcldc3 + 4); - cv4 = _mm256_loadu_pd(pcldc4); cx4 = _mm256_loadu_pd(pcldc4 + 4); -#else - cv0 = _mm256_setzero_pd(); cx0 = _mm256_setzero_pd(); - cv1 = _mm256_setzero_pd(); cx1 = _mm256_setzero_pd(); - cv2 = _mm256_setzero_pd(); cx2 = _mm256_setzero_pd(); - cv3 = _mm256_setzero_pd(); cx3 = _mm256_setzero_pd(); - cv4 = _mm256_setzero_pd(); cx4 = _mm256_setzero_pd(); -#endif - double* x = aPacked; - double* 
pb0 = pb; - for (p = 0; p < k; p += 1) { - bv0 = _mm256_broadcast_sd(pb0); pb0++; - bv1 = _mm256_broadcast_sd(pbldb); pbldb++; - bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; - bv3 = _mm256_broadcast_sd(pbldb3);pbldb3++; - bv4 = _mm256_broadcast_sd(pbldb4);pbldb4++; - - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - cv1 = _mm256_fmadd_pd(av0, bv1, cv1); - cv2 = _mm256_fmadd_pd(av0, bv2, cv2); - cv3 = _mm256_fmadd_pd(av0, bv3, cv3); - cv4 = _mm256_fmadd_pd(av0, bv4, cv4); - - av0 = _mm256_loadu_pd(x); x += 4; - cx0 = _mm256_fmadd_pd(av0, bv0, cx0); - cx1 = _mm256_fmadd_pd(av0, bv1, cx1); - cx2 = _mm256_fmadd_pd(av0, bv2, cx2); - cx3 = _mm256_fmadd_pd(av0, bv3, cx3); - cx4 = _mm256_fmadd_pd(av0, bv4, cx4); - } -#if BLIS_LOADFIRST -#else - bv0 = _mm256_loadu_pd(pc); bv1 = _mm256_loadu_pd(pc + 4); - cv0 = _mm256_add_pd(cv0, bv0); cx0 = _mm256_add_pd(cx0, bv1); - - bv2 = _mm256_loadu_pd(pcldc); bv3 = _mm256_loadu_pd(pcldc + 4); - cv1 = _mm256_add_pd(cv1, bv2); cx1 = _mm256_add_pd(cx1, bv3); - - bv0 = _mm256_loadu_pd(pcldc2); bv1 = _mm256_loadu_pd(pcldc2 + 4); - cv2 = _mm256_add_pd(cv2, bv0); cx2 = _mm256_add_pd(cx2, bv1); - - bv2 = _mm256_loadu_pd(pcldc3); bv3 = _mm256_loadu_pd(pcldc3 + 4); - cv3 = _mm256_add_pd(cv3, bv2); cx3 = _mm256_add_pd(cx3, bv3); - - bv0 = _mm256_loadu_pd(pcldc4); bv1 = _mm256_loadu_pd(pcldc4 + 4); - cv4 = _mm256_add_pd(cv4, bv0); cx4 = _mm256_add_pd(cx4, bv1); -#endif - /* C matrix column major store */ - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pc + 4, cx0); - - _mm256_storeu_pd(pcldc, cv1); - _mm256_storeu_pd(pcldc + 4, cx1); - - _mm256_storeu_pd(pcldc2, cv2); - _mm256_storeu_pd(pcldc2 + 4, cx2); - - _mm256_storeu_pd(pcldc3, cv3); - _mm256_storeu_pd(pcldc3 + 4, cx3); - - _mm256_storeu_pd(pcldc4, cv4); - _mm256_storeu_pd(pcldc4 + 4, cx4); - - pc += ldc5;pb += ldb5; - } - //printf(" 8x5:j:%d ", j); - return j; -} - -/* residue dgemm kernel 8mx4n with single load and store of C matrix block - Code could be optimized further, complete ymm register set is not used. - Being residue kernel, its of lesser priority. -*/ -inc_t bli_sqp_dgemm_kernel_8mx4n( gint_t n, - gint_t k, - gint_t j, - double* aPacked, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc) -{ - gint_t p; - __m256d av0; - __m256d bv0, bv1, bv2, bv3; - __m256d cv0, cv1, cv2, cv3; - __m256d cx0, cx1, cx2, cx3; - double* pb, * pc; - - pb = b; - pc = c; - inc_t ldc4 = ldc * 4; inc_t ldb4 = ldb * 4; - - for (; j <= (n - 4); j += 4) { - - double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; double* pcldc3 = pcldc2 + ldc; - double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; double* pbldb3 = pbldb2 + ldb; - - cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); - cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); - cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); - cv3 = _mm256_loadu_pd(pcldc3); cx3 = _mm256_loadu_pd(pcldc3 + 4); - { - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - // better kernel to be written since more register are available. 
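These 8mxN kernels all share one inner step: load a packed 8-double column of A into two ymm registers, broadcast a single element of B, and FMA into a pair of accumulators; with BLIS_LOADFIRST set to 0 the accumulators start from zero and the C block is loaded and added only once after the k loop, keeping C traffic out of the hot loop. A standalone sketch of that inner step (AVX2 + FMA assumed; sqp_fma_8x1 is a hypothetical name):

#include <immintrin.h>

/* One packed 8-double column of A times one broadcast element of B,
   accumulated into the low/high halves of a C column. */
static inline void sqp_fma_8x1(const double *a8, const double *bp,
                               __m256d *c_lo, __m256d *c_hi)
{
    __m256d av0 = _mm256_loadu_pd(a8);       /* A rows 0..3               */
    __m256d av1 = _mm256_loadu_pd(a8 + 4);   /* A rows 4..7               */
    __m256d bv  = _mm256_broadcast_sd(bp);   /* B(p, j) in all four lanes */

    *c_lo = _mm256_fmadd_pd(av0, bv, *c_lo);
    *c_hi = _mm256_fmadd_pd(av1, bv, *c_hi);
}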
- bv0 = _mm256_broadcast_sd(pb0); pb0++; - bv1 = _mm256_broadcast_sd(pbldb); pbldb++; - bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; - bv3 = _mm256_broadcast_sd(pbldb3); pbldb3++; - - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - cv1 = _mm256_fmadd_pd(av0, bv1, cv1); - cv2 = _mm256_fmadd_pd(av0, bv2, cv2); - cv3 = _mm256_fmadd_pd(av0, bv3, cv3); - - av0 = _mm256_loadu_pd(x); x += 4; - cx0 = _mm256_fmadd_pd(av0, bv0, cx0); - cx1 = _mm256_fmadd_pd(av0, bv1, cx1); - cx2 = _mm256_fmadd_pd(av0, bv2, cx2); - cx3 = _mm256_fmadd_pd(av0, bv3, cx3); - } - } - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pc + 4, cx0); - _mm256_storeu_pd(pcldc, cv1); - _mm256_storeu_pd(pcldc + 4, cx1); - _mm256_storeu_pd(pcldc2, cv2); - _mm256_storeu_pd(pcldc2 + 4, cx2); - _mm256_storeu_pd(pcldc3, cv3); - _mm256_storeu_pd(pcldc3 + 4, cx3); - - pc += ldc4;pb += ldb4; - }// j loop 4 multiple - //printf(" 8x4:j:%d ", j); - return j; -} - -/* residue dgemm kernel 8mx3n with single load and store of C matrix block - Code could be optimized further, complete ymm register set is not used. - Being residue kernel, its of lesser priority. -*/ -inc_t bli_sqp_dgemm_kernel_8mx3n( gint_t n, - gint_t k, - gint_t j, - double* aPacked, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc) -{ - gint_t p; - __m256d av0; - __m256d bv0, bv1, bv2; - __m256d cv0, cv1, cv2; - __m256d cx0, cx1, cx2; - double* pb, * pc; - - pb = b; - pc = c; - - inc_t ldc3 = ldc * 3; inc_t ldb3 = ldb * 3; - - for (; j <= (n - 3); j += 3) { - - double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; - double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; - - cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); - cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); - cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); - { - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - bv0 = _mm256_broadcast_sd(pb0); pb0++; - bv1 = _mm256_broadcast_sd(pbldb); pbldb++; - bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; - - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - cv1 = _mm256_fmadd_pd(av0, bv1, cv1); - cv2 = _mm256_fmadd_pd(av0, bv2, cv2); - - av0 = _mm256_loadu_pd(x); x += 4; - cx0 = _mm256_fmadd_pd(av0, bv0, cx0); - cx1 = _mm256_fmadd_pd(av0, bv1, cx1); - cx2 = _mm256_fmadd_pd(av0, bv2, cx2); - } - } - - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pc + 4, cx0); - _mm256_storeu_pd(pcldc, cv1); - _mm256_storeu_pd(pcldc + 4, cx1); - _mm256_storeu_pd(pcldc2, cv2); - _mm256_storeu_pd(pcldc2 + 4, cx2); - - pc += ldc3;pb += ldb3; - }// j loop 3 multiple - //printf(" 8x3:j:%d ", j); - return j; -} - -/* residue dgemm kernel 8mx2n with single load and store of C matrix block - Code could be optimized further, complete ymm register set is not used. - Being residue kernel, its of lesser priority. 
-*/ -inc_t bli_sqp_dgemm_kernel_8mx2n( gint_t n, - gint_t k, - gint_t j, - double* aPacked, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc) -{ - gint_t p; - __m256d av0; - __m256d bv0, bv1; - __m256d cv0, cv1; - __m256d cx0, cx1; - double* pb, * pc; - - pb = b; - pc = c; - inc_t ldc2 = ldc * 2; inc_t ldb2 = ldb * 2; - - for (; j <= (n - 2); j += 2) { - double* pcldc = pc + ldc; - double* pbldb = pb + ldb; - - cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); - cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); - { - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - bv0 = _mm256_broadcast_sd(pb0); pb0++; - bv1 = _mm256_broadcast_sd(pbldb); pbldb++; - - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - cv1 = _mm256_fmadd_pd(av0, bv1, cv1); - - av0 = _mm256_loadu_pd(x); x += 4; - cx0 = _mm256_fmadd_pd(av0, bv0, cx0); - cx1 = _mm256_fmadd_pd(av0, bv1, cx1); - } - } - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pc + 4, cx0); - _mm256_storeu_pd(pcldc, cv1); - _mm256_storeu_pd(pcldc + 4, cx1); - - pc += ldc2;pb += ldb2; - }// j loop 2 multiple - //printf(" 8x2:j:%d ", j); - return j; -} - -/* residue dgemm kernel 8mx1n with single load and store of C matrix block - Code could be optimized further, complete ymm register set is not used. - Being residue kernel, its of lesser priority. -*/ -inc_t bli_sqp_dgemm_kernel_8mx1n( gint_t n, - gint_t k, - gint_t j, - double* aPacked, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc) -{ - gint_t p; - __m256d av0; - __m256d bv0; - __m256d cv0; - __m256d cx0; - double* pb, * pc; - - pb = b; - pc = c; - - for (; j <= (n - 1); j += 1) { - cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - bv0 = _mm256_broadcast_sd(pb0); pb0++; - - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - - av0 = _mm256_loadu_pd(x); x += 4; - cx0 = _mm256_fmadd_pd(av0, bv0, cx0); - } - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pc + 4, cx0); - pc += ldc;pb += ldb; - }// j loop 1 multiple - //printf(" 8x1:j:%d ", j); - return j; -} - -#if 0 -/************************************************************************************************************/ -/************************** dgemm kernels (4mxn) column preffered ******************************************/ -/************************************************************************************************************/ -/* Residue dgemm kernel 4mx10n with single load and store of C matrix block - alpha = +/-1 and beta = +/-1,0 handled while packing.*/ -inc_t bli_sqp_dgemm_kernel_4mx10n( gint_t n, - gint_t k, - gint_t j, - double* aPacked, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc) -{ - gint_t p; - /* incomplete */ - __m256d av0; - __m256d bv0, bv1, bv2, bv3; - __m256d cv0, cv1, cv2, cv3; - __m256d cx0, cx1, cx2, cx3; - __m256d bv4, cv4, cx4; - double* pb, * pc; - - pb = b; - pc = c; - inc_t ldc10 = ldc * 10; inc_t ldb10 = ldb * 10; - - for (j = 0; j <= (n - 10); j += 10) { - - double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; double* pcldc3 = pcldc2 + ldc; double* pcldc4 = pcldc3 + ldc; - double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; double* pbldb3 = pbldb2 + ldb; double* pbldb4 = pbldb3 + ldb; - -#if BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(pc), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); - 
_mm_prefetch((char*)(pcldc3), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc4), _MM_HINT_T0); - - _mm_prefetch((char*)(aPacked), _MM_HINT_T0); - - _mm_prefetch((char*)(pb), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb2), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); -#endif - /* C matrix column major load */ -#if BLIS_LOADFIRST - cv0 = _mm256_loadu_pd(pc); - cv1 = _mm256_loadu_pd(pcldc); - cv2 = _mm256_loadu_pd(pcldc2); - cv3 = _mm256_loadu_pd(pcldc3); - cv4 = _mm256_loadu_pd(pcldc4); -#else - cv0 = _mm256_setzero_pd(); - cv1 = _mm256_setzero_pd(); - cv2 = _mm256_setzero_pd(); - cv3 = _mm256_setzero_pd(); - cv4 = _mm256_setzero_pd(); -#endif - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - bv0 = _mm256_broadcast_sd(pb0); pb0++; - bv1 = _mm256_broadcast_sd(pbldb); pbldb++; - bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; - bv3 = _mm256_broadcast_sd(pbldb3);pbldb3++; - bv4 = _mm256_broadcast_sd(pbldb4);pbldb4++; - - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - cv1 = _mm256_fmadd_pd(av0, bv1, cv1); - cv2 = _mm256_fmadd_pd(av0, bv2, cv2); - cv3 = _mm256_fmadd_pd(av0, bv3, cv3); - cv4 = _mm256_fmadd_pd(av0, bv4, cv4); - - } -#if BLIS_LOADFIRST -#else - bv0 = _mm256_loadu_pd(pc); - cv0 = _mm256_add_pd(cv0, bv0); - - bv2 = _mm256_loadu_pd(pcldc); - cv1 = _mm256_add_pd(cv1, bv2); - - bv0 = _mm256_loadu_pd(pcldc2); - cv2 = _mm256_add_pd(cv2, bv0); - - bv2 = _mm256_loadu_pd(pcldc3); - cv3 = _mm256_add_pd(cv3, bv2); - - bv0 = _mm256_loadu_pd(pcldc4); - cv4 = _mm256_add_pd(cv4, bv0); -#endif - /* C matrix column major store */ - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pcldc, cv1); - _mm256_storeu_pd(pcldc2, cv2); - _mm256_storeu_pd(pcldc3, cv3); - _mm256_storeu_pd(pcldc4, cv4); - - - pc += ldc10;pb += ldb10; - } - - return j; -} - -/* residue dgemm kernel 4mx1n with single load and store of C matrix block - Code could be optimized further, complete ymm register set is not used. - Being residue kernel, its of lesser priority. -*/ -inc_t bli_sqp_dgemm_kernel_4mx1n( gint_t n, - gint_t k, - gint_t j, - double* aPacked, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc) -{ - gint_t p; - __m256d av0; - __m256d bv0; - __m256d cv0; - double* pb, * pc; - - pb = b; - pc = c; - - for (; j <= (n - 1); j += 1) { - cv0 = _mm256_loadu_pd(pc); - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - bv0 = _mm256_broadcast_sd(pb0); pb0++; - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - } - _mm256_storeu_pd(pc, cv0); - pc += ldc;pb += ldb; - }// j loop 1 multiple - return j; -} - -#endif -/************************************************************************************************************/ -/************************** dgemm kernels (1mxn) column preffered ******************************************/ -/************************************************************************************************************/ - -/* residue dgemm kernel 1mx1n with single load and store of C matrix block - Code could be optimized further, complete ymm register set is not used. - Being residue kernel, its of lesser priority. 
-*/ -inc_t bli_sqp_dgemm_kernel_1mx1n( gint_t n, - gint_t k, - gint_t j, - double* aPacked, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc) -{ - gint_t p; - double a0; - double b0; - double c0; - double* pb, * pc; - - pb = b; - pc = c; - - for (; j <= (n - 1); j += 1) { - c0 = *pc; - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - b0 = *pb0; pb0++; - a0 = *x; x++; - c0 += (a0 * b0); - } - *pc = c0; - pc += ldc;pb += ldb; - }// j loop 1 multiple - //printf(" 1x1:j:%d ", j); - return j; -} - -inc_t bli_sqp_dgemm_kernel_mxn( gint_t n, - gint_t k, - gint_t j, - double* aPacked, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc, - gint_t mx) -{ - gint_t p; - double cx[7]; - - double* pb, * pc; - - pb = b; - pc = c; - - for (; j <= (n - 1); j += 1) { - //cv0 = _mm256_loadu_pd(pc); - for (int i = 0; i < mx; i++) - { - cx[i] = *(pc + i); - } - - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - //bv0 = _mm256_broadcast_sd(pb0); - double b0 = *pb0; - pb0++; - for (int i = 0; i < mx; i++) - { - cx[i] += (*(x + i)) * b0;//cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - } - //av0 = _mm256_loadu_pd(x); - x += mx; - } - //_mm256_storeu_pd(pc, cv0); - for (int i = 0; i < mx; i++) - { - *(pc + i) = cx[i]; - } - pc += ldc;pb += ldb; - }// j loop 1 multiple - //printf(" mx1:j:%d ", j); - return j; -} - -void bli_sqp_prepackA( double* pa, - double* aPacked, - gint_t k, - guint_t lda, - bool isTransA, - double alpha, - gint_t mx) -{ - //printf(" pmx:%d ",mx); - if(mx==8) - { - bli_prepackA_8(pa,aPacked,k, lda,isTransA, alpha); - } - else if(mx==4) - { - bli_prepackA_4(pa,aPacked,k, lda,isTransA, alpha); - } - else if(mx>4) - { - bli_prepackA_G4(pa,aPacked,k, lda,isTransA, alpha, mx); - } - else - { - bli_prepackA_L4(pa,aPacked,k, lda,isTransA, alpha, mx); - } -} - -/* Ax8 packing subroutine */ -void bli_prepackA_8(double* pa, - double* aPacked, - gint_t k, - guint_t lda, - bool isTransA, - double alpha) -{ - __m256d av0, av1, ymm0; - if(isTransA==false) - { - if(alpha==1.0) - { - for (gint_t p = 0; p < k; p += 1) { - av0 = _mm256_loadu_pd(pa); av1 = _mm256_loadu_pd(pa + 4); pa += lda; - _mm256_storeu_pd(aPacked, av0); _mm256_storeu_pd(aPacked + 4, av1); - aPacked += BLIS_MX8; - } - } - else if(alpha==-1.0) - { - ymm0 = _mm256_setzero_pd();//set zero - for (gint_t p = 0; p < k; p += 1) { - av0 = _mm256_loadu_pd(pa); av1 = _mm256_loadu_pd(pa + 4); pa += lda; - av0 = _mm256_sub_pd(ymm0,av0); av1 = _mm256_sub_pd(ymm0,av1); // a = 0 - a; - _mm256_storeu_pd(aPacked, av0); _mm256_storeu_pd(aPacked + 4, av1); - aPacked += BLIS_MX8; - } - } - } - else //subroutine below to be optimized - { - if(alpha==1.0) - { - //A Transpose case: - for (gint_t i = 0; i < BLIS_MX8 ; i++) - { - gint_t idx = i * lda; - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+idx+p); - gint_t sidx = p * BLIS_MX8; - *(aPacked + sidx + i) = ar_; - } - } - } - else if(alpha==-1.0) - { - //A Transpose case: - for (gint_t i = 0; i < BLIS_MX8 ; i++) - { - gint_t idx = i * lda; - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+idx+p); - gint_t sidx = p * BLIS_MX8; - *(aPacked + sidx + i) = -ar_; - } - } - } - } -} - -/* Ax4 packing subroutine */ -void bli_prepackA_4(double* pa, - double* aPacked, - gint_t k, - guint_t lda, - bool isTransA, - double alpha) -{ - __m256d av0, ymm0; - if(isTransA==false) - { - if(alpha==1.0) - { - for (gint_t p = 0; p < k; p += 1) { - av0 = _mm256_loadu_pd(pa); pa += lda; - _mm256_storeu_pd(aPacked, av0); - aPacked += 
BLIS_MX4; - } - } - else if(alpha==-1.0) - { - ymm0 = _mm256_setzero_pd();//set zero - for (gint_t p = 0; p < k; p += 1) { - av0 = _mm256_loadu_pd(pa); pa += lda; - av0 = _mm256_sub_pd(ymm0,av0); // a = 0 - a; - _mm256_storeu_pd(aPacked, av0); - aPacked += BLIS_MX4; - } - } - } - else //subroutine below to be optimized - { - if(alpha==1.0) - { - //A Transpose case: - for (gint_t i = 0; i < BLIS_MX4 ; i++) - { - gint_t idx = i * lda; - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+idx+p); - gint_t sidx = p * BLIS_MX4; - *(aPacked + sidx + i) = ar_; - } - } - } - else if(alpha==-1.0) - { - //A Transpose case: - for (gint_t i = 0; i < BLIS_MX4 ; i++) - { - gint_t idx = i * lda; - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+idx+p); - gint_t sidx = p * BLIS_MX4; - *(aPacked + sidx + i) = -ar_; - } - } - } - } - -} - -/* A packing m>4 subroutine */ -void bli_prepackA_G4( double* pa, - double* aPacked, - gint_t k, - guint_t lda, - bool isTransA, - double alpha, - gint_t mx) -{ - __m256d av0, ymm0; - gint_t mrem = mx - 4; - - if(isTransA==false) - { - if(alpha==1.0) - { - for (gint_t p = 0; p < k; p += 1) { - av0 = _mm256_loadu_pd(pa); - _mm256_storeu_pd(aPacked, av0); - for (gint_t i = 0; i < mrem; i += 1) { - *(aPacked+4+i) = *(pa+4+i); - } - aPacked += mx;pa += lda; - } - } - else if(alpha==-1.0) - { - ymm0 = _mm256_setzero_pd();//set zero - for (gint_t p = 0; p < k; p += 1) { - av0 = _mm256_loadu_pd(pa); - av0 = _mm256_sub_pd(ymm0,av0); // a = 0 - a; - _mm256_storeu_pd(aPacked, av0); - for (gint_t i = 0; i < mrem; i += 1) { - *(aPacked+4+i) = -*(pa+4+i); - } - aPacked += mx;pa += lda; - } - } - } - else //subroutine below to be optimized - { - if(alpha==1.0) - { - //A Transpose case: - for (gint_t i = 0; i < mx ; i++) - { - gint_t idx = i * lda; - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+idx+p); - gint_t sidx = p * mx; - *(aPacked + sidx + i) = ar_; - } - } - } - else if(alpha==-1.0) - { - //A Transpose case: - for (gint_t i = 0; i < mx ; i++) - { - gint_t idx = i * lda; - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+idx+p); - gint_t sidx = p * mx; - *(aPacked + sidx + i) = -ar_; - } - } - } - } - -} - -/* A packing m<4 subroutine */ -void bli_prepackA_L4( double* pa, - double* aPacked, - gint_t k, - guint_t lda, - bool isTransA, - double alpha, - gint_t mx) -{ - if(isTransA==false) - { - if(alpha==1.0) - { - for (gint_t p = 0; p < k; p += 1) - { - for (gint_t i = 0; i < mx; i += 1) - { - *(aPacked+i) = *(pa+i); - } - aPacked += mx;pa += lda; - } - } - else if(alpha==-1.0) - { - for (gint_t p = 0; p < k; p += 1) - { - for (gint_t i = 0; i < mx; i += 1) - { - *(aPacked+i) = -*(pa+i); - } - aPacked += mx;pa += lda; - } - } - } - else - { - if(alpha==1.0) - { - //A Transpose case: - for (gint_t i = 0; i < mx ; i++) - { - gint_t idx = i * lda; - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+idx+p); - gint_t sidx = p * mx; - *(aPacked + sidx + i) = ar_; - } - } - } - else if(alpha==-1.0) - { - //A Transpose case: - for (gint_t i = 0; i < mx ; i++) - { - gint_t idx = i * lda; - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+idx+p); - gint_t sidx = p * mx; - *(aPacked + sidx + i) = -ar_; - } - } - } - } - - -} - -/* Ax1 packing subroutine */ -void bli_prepackA_1(double* pa, - double* aPacked, - gint_t k, - guint_t lda, - bool isTransA, - double alpha) -{ - if(isTransA==false) - { - if(alpha==1.0) - { - for (gint_t p = 0; p < k; p += 1) { - *aPacked = *pa; - pa += lda; - aPacked++; - } - } - else if(alpha==-1.0) - { - for (gint_t p = 
0; p < k; p += 1) { - *aPacked = -(*pa); - pa += lda; - aPacked++; - } - } - } - else - { - if(alpha==1.0) - { - //A Transpose case: - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+p); - *(aPacked + p) = ar_; - } - } - else if(alpha==-1.0) - { - //A Transpose case: - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+p); - *(aPacked + p) = -ar_; - } - } - } -} - - -void bli_add_m( gint_t m, - gint_t n, - double* w, - double* c) -{ - double* pc = c; - double* pw = w; - gint_t count = m*n; - gint_t i = 0; - __m256d cv0, wv0; - - for (; i <= (count-4); i+=4) - { - cv0 = _mm256_loadu_pd(pc); - wv0 = _mm256_loadu_pd(pw); pw += 4; - cv0 = _mm256_add_pd(cv0,wv0); - _mm256_storeu_pd(pc, cv0); pc += 4; - } - for (; i < count; i++) - { - *pc = *pc + *pw; - pc++; pw++; - } -} - -void bli_sub_m( gint_t m, - gint_t n, - double* w, - double* c) -{ - double* pc = c; - double* pw = w; - gint_t count = m*n; - gint_t i = 0; - __m256d cv0, wv0; - - for (; i <= (count-4); i+=4) - { - cv0 = _mm256_loadu_pd(pc); - wv0 = _mm256_loadu_pd(pw); pw += 4; - cv0 = _mm256_sub_pd(cv0,wv0); - _mm256_storeu_pd(pc, cv0); pc += 4; - } - for (; i < count; i++) - { - *pc = *pc - *pw; - pc++; pw++; - } -} - -/* Pack real and imaginary parts in separate buffers and also multipy with multiplication factor */ -void bli_3m_sqp_packC_real_imag(double* pc, - guint_t n, - guint_t m, - guint_t ldc, - double* pcr, - double* pci, - double mul, - gint_t mx) -{ - gint_t j, p; - __m256d av0, av1, zerov; - __m256d tv0, tv1; - gint_t max_m = (m*2)-8; - - if((mul ==1.0)||(mul==-1.0)) - { - if(mul ==1.0) /* handles alpha or beta = 1.0 */ - { - for (j = 0; j < n; j++) - { - for (p = 0; p <= max_m; p += 8) - { - double* pbp = pc + p; - av0 = _mm256_loadu_pd(pbp); //ai1, ar1, ai0, ar0 - av1 = _mm256_loadu_pd(pbp+4); //ai3, ar3, ai2, ar2 - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 - av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 - av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 - - _mm256_storeu_pd(pcr, av0); pcr += 4; - _mm256_storeu_pd(pci, av1); pci += 4; - } - - for (; p < (m*2); p += 2)// (real + imag)*m - { - double br = *(pc + p) ; - double bi = *(pc + p + 1); - *pcr = br; - *pci = bi; - pcr++; pci++; - } - pc = pc + ldc; - } - } - else /* handles alpha or beta = - 1.0 */ - { - zerov = _mm256_setzero_pd(); - for (j = 0; j < n; j++) - { - for (p = 0; p <= max_m; p += 8) - { - double* pbp = pc + p; - av0 = _mm256_loadu_pd(pbp); //ai1, ar1, ai0, ar0 - av1 = _mm256_loadu_pd(pbp+4);//ai3, ar3, ai2, ar2 - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 - av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 - av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 - - //negate - av0 = _mm256_sub_pd(zerov,av0); - av1 = _mm256_sub_pd(zerov,av1); - - _mm256_storeu_pd(pcr, av0); pcr += 4; - _mm256_storeu_pd(pci, av1); pci += 4; - } - - for (; p < (m*2); p += 2)// (real + imag)*m - { - double br = -*(pc + p) ; - double bi = -*(pc + p + 1); - *pcr = br; - *pci = bi; - pcr++; pci++; - } - pc = pc + ldc; - } - } - } - else if(mul==0) /* handles alpha or beta is equal to zero */ - { - double br_ = 0; - double bi_ = 0; - for (j = 0; j < n; j++) - { - for (p = 0; p < (m*2); p += 2)// (real + imag)*m - { - *pcr = br_; - *pci = bi_; - pcr++; pci++; - } - pc = pc + ldc; - } - } - else /* handles alpha or beta is not equal +/- 1.0 and zero */ - { - 
for (j = 0; j < n; j++) - { - for (p = 0; p < (m*2); p += 2)// (real + imag)*m - { - double br_ = mul * (*(pc + p)); - double bi_ = mul * (*(pc + p + 1)); - *pcr = br_; - *pci = bi_; - pcr++; pci++; - } - pc = pc + ldc; - } - } -} - -/* Pack real and imaginary parts in separate buffers and compute sum of real and imaginary part */ -void bli_3m_sqp_packB_real_imag_sum(double* pb, - guint_t n, - guint_t k, - guint_t ldb, - double* pbr, - double* pbi, - double* pbs, - double mul, - gint_t mx) -{ - gint_t j, p; - __m256d av0, av1, zerov; - __m256d tv0, tv1, sum; - gint_t max_k = (k*2) - 8; - if((mul ==1.0)||(mul==-1.0)) - { - if(mul ==1.0) - { - for (j = 0; j < n; j++) - { - for (p=0; p <= max_k; p += 8) - { - double* pbp = pb + p; - av0 = _mm256_loadu_pd(pbp);//ai1, ar1, ai0, ar0 - av1 = _mm256_loadu_pd(pbp+4);//ai3, ar3, ai2, ar2 - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 - av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 - av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(pbr, av0); pbr += 4; - _mm256_storeu_pd(pbi, av1); pbi += 4; - _mm256_storeu_pd(pbs, sum); pbs += 4; - } - - for (; p < (k*2); p += 2)// (real + imag)*k - { - double br = *(pb + p) ; - double bi = *(pb + p + 1); - *pbr = br; - *pbi = bi; - *pbs = br + bi; - - pbr++; pbi++; pbs++; - } - pb = pb + ldb; - } - } - else - { - zerov = _mm256_setzero_pd(); - for (j = 0; j < n; j++) - { - for (p = 0; p <= max_k; p += 8) - { - double* pbp = pb + p; - av0 = _mm256_loadu_pd(pbp);//ai1, ar1, ai0, ar0 - av1 = _mm256_loadu_pd(pbp+4);//ai3, ar3, ai2, ar2 - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 - av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 - av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 - - //negate - av0 = _mm256_sub_pd(zerov,av0); - av1 = _mm256_sub_pd(zerov,av1); - - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(pbr, av0); pbr += 4; - _mm256_storeu_pd(pbi, av1); pbi += 4; - _mm256_storeu_pd(pbs, sum); pbs += 4; - } - - for (; p < (k*2); p += 2)// (real + imag)*k - { - double br = -*(pb + p) ; - double bi = -*(pb + p + 1); - *pbr = br; - *pbi = bi; - *pbs = br + bi; - - pbr++; pbi++; pbs++; - } - pb = pb + ldb; - } - } - } - else - { - for (j = 0; j < n; j++) - { - for (p = 0; p < (k*2); p += 2)// (real + imag)*k - { - double br_ = mul * (*(pb + p)); - double bi_ = mul * (*(pb + p + 1)); - *pbr = br_; - *pbi = bi_; - *pbs = br_ + bi_; - - pbr++; pbi++; pbs++; - } - pb = pb + ldb; - } - } -} - -/* Pack real and imaginary parts of A matrix in separate buffers and compute sum of real and imaginary part */ -void bli_3m_sqp_packA_real_imag_sum(double *pa, - gint_t i, - guint_t k, - guint_t lda, - double *par, - double *pai, - double *pas, - trans_t transa, - gint_t mx, - gint_t p) -{ - __m256d av0, av1, av2, av3; - __m256d tv0, tv1, sum, zerov; - gint_t poffset = p; -#if KLP -#endif - if(mx==8) - { - if(transa == BLIS_NO_TRANSPOSE) - { - pa = pa +i; -#if KLP - pa = pa + (p*lda); -#else - p = 0; -#endif - for (; p < k; p += 1) - { - //for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each. 
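The bli_3m_sqp_pack* routines here split interleaved dcomplex data into separate real and imaginary streams, producing the real+imaginary sum as a by-product, using a permute2f128/unpack shuffle pair. A standalone sketch of the 4-element deinterleave step (AVX2 assumed; split_ri_4 is an illustrative name, and the register comments follow the source's high-to-low lane notation):

#include <immintrin.h>

/* Deinterleave four dcomplex values [r0 i0 r1 i1 r2 i2 r3 i3] into a
   vector of reals, a vector of imaginaries, and their element-wise sum. */
static inline void split_ri_4(const double *z,
                              double *re, double *im, double *sum)
{
    __m256d av0 = _mm256_loadu_pd(z);        /* i1 r1 i0 r0 */
    __m256d av1 = _mm256_loadu_pd(z + 4);    /* i3 r3 i2 r2 */

    __m256d t0 = _mm256_permute2f128_pd(av0, av1, 0x20); /* i2 r2 i0 r0 */
    __m256d t1 = _mm256_permute2f128_pd(av0, av1, 0x31); /* i3 r3 i1 r1 */

    __m256d rv = _mm256_unpacklo_pd(t0, t1); /* r3 r2 r1 r0 */
    __m256d iv = _mm256_unpackhi_pd(t0, t1); /* i3 i2 i1 i0 */

    _mm256_storeu_pd(re,  rv);
    _mm256_storeu_pd(im,  iv);
    _mm256_storeu_pd(sum, _mm256_add_pd(rv, iv));
}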
- av0 = _mm256_loadu_pd(pa); - av1 = _mm256_loadu_pd(pa+4); - av2 = _mm256_loadu_pd(pa+8); - av3 = _mm256_loadu_pd(pa+12); - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(par, av0); par += 4; - _mm256_storeu_pd(pai, av1); pai += 4; - _mm256_storeu_pd(pas, sum); pas += 4; - - tv0 = _mm256_permute2f128_pd(av2, av3, 0x20); - tv1 = _mm256_permute2f128_pd(av2, av3, 0x31); - av2 = _mm256_unpacklo_pd(tv0, tv1); - av3 = _mm256_unpackhi_pd(tv0, tv1); - sum = _mm256_add_pd(av2, av3); - _mm256_storeu_pd(par, av2); par += 4; - _mm256_storeu_pd(pai, av3); pai += 4; - _mm256_storeu_pd(pas, sum); pas += 4; - - pa = pa + lda; - } - } - else if(transa == BLIS_CONJ_NO_TRANSPOSE) - { - zerov = _mm256_setzero_pd(); - pa = pa +i; -#if KLP - pa = pa + (p*lda); -#else - p = 0; -#endif - for (; p < k; p += 1) - { - //for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each. - av0 = _mm256_loadu_pd(pa); - av1 = _mm256_loadu_pd(pa+4); - av2 = _mm256_loadu_pd(pa+8); - av3 = _mm256_loadu_pd(pa+12); - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - av1 = _mm256_sub_pd(zerov,av1);//negate imaginary component - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(par, av0); par += 4; - _mm256_storeu_pd(pai, av1); pai += 4; - _mm256_storeu_pd(pas, sum); pas += 4; - - tv0 = _mm256_permute2f128_pd(av2, av3, 0x20); - tv1 = _mm256_permute2f128_pd(av2, av3, 0x31); - av2 = _mm256_unpacklo_pd(tv0, tv1); - av3 = _mm256_unpackhi_pd(tv0, tv1); - av3 = _mm256_sub_pd(zerov,av3);//negate imaginary component - sum = _mm256_add_pd(av2, av3); - _mm256_storeu_pd(par, av2); par += 4; - _mm256_storeu_pd(pai, av3); pai += 4; - _mm256_storeu_pd(pas, sum); pas += 4; - - pa = pa + lda; - } - } - else if(transa == BLIS_TRANSPOSE) - { - gint_t idx = (i/2) * lda; - pa = pa + idx; -#if KLP -#else - p = 0; -#endif - //A Transpose case: - for (gint_t ii = 0; ii < BLIS_MX8 ; ii++) - { - gint_t idx = ii * lda; - gint_t sidx; - gint_t pidx = 0; - gint_t max_k = (k*2) - 8; - for (p = poffset; p <= max_k; p += 8) - { - double ar0_ = *(pa + idx + p); - double ai0_ = *(pa + idx + p + 1); - - double ar1_ = *(pa + idx + p + 2); - double ai1_ = *(pa + idx + p + 3); - - double ar2_ = *(pa + idx + p + 4); - double ai2_ = *(pa + idx + p + 5); - - double ar3_ = *(pa + idx + p + 6); - double ai3_ = *(pa + idx + p + 7); - - sidx = (pidx/2) * BLIS_MX8; - *(par + sidx + ii) = ar0_; - *(pai + sidx + ii) = ai0_; - *(pas + sidx + ii) = ar0_ + ai0_; - - sidx = ((pidx+2)/2) * BLIS_MX8; - *(par + sidx + ii) = ar1_; - *(pai + sidx + ii) = ai1_; - *(pas + sidx + ii) = ar1_ + ai1_; - - sidx = ((pidx+4)/2) * BLIS_MX8; - *(par + sidx + ii) = ar2_; - *(pai + sidx + ii) = ai2_; - *(pas + sidx + ii) = ar2_ + ai2_; - - sidx = ((pidx+6)/2) * BLIS_MX8; - *(par + sidx + ii) = ar3_; - *(pai + sidx + ii) = ai3_; - *(pas + sidx + ii) = ar3_ + ai3_; - pidx += 8; - - } - - for (; p < (k*2); p += 2) - { - double ar_ = *(pa + idx + p); - double ai_ = *(pa + idx + p + 1); - gint_t sidx = (pidx/2) * BLIS_MX8; - *(par + sidx + ii) = ar_; - *(pai + sidx + ii) = ai_; - *(pas + sidx + ii) = ar_ + ai_; - pidx += 2; - } - } - } - else if(transa == BLIS_CONJ_TRANSPOSE) - { - gint_t idx = (i/2) * lda; - pa = pa + idx; -#if KLP -#else - p = 0; -#endif - //A conjugate Transpose 
case: - for (gint_t ii = 0; ii < BLIS_MX8 ; ii++) - { - gint_t idx = ii * lda; - gint_t sidx; - gint_t pidx = 0; - gint_t max_k = (k*2) - 8; - for (p = poffset; p <= max_k; p += 8) - { - double ar0_ = *(pa + idx + p); - double ai0_ = -(*(pa + idx + p + 1)); - - double ar1_ = *(pa + idx + p + 2); - double ai1_ = -(*(pa + idx + p + 3)); - - double ar2_ = *(pa + idx + p + 4); - double ai2_ = -(*(pa + idx + p + 5)); - - double ar3_ = *(pa + idx + p + 6); - double ai3_ = -(*(pa + idx + p + 7)); - - sidx = (pidx/2) * BLIS_MX8; - *(par + sidx + ii) = ar0_; - *(pai + sidx + ii) = ai0_; - *(pas + sidx + ii) = ar0_ + ai0_; - - sidx = ((pidx+2)/2) * BLIS_MX8; - *(par + sidx + ii) = ar1_; - *(pai + sidx + ii) = ai1_; - *(pas + sidx + ii) = ar1_ + ai1_; - - sidx = ((pidx+4)/2) * BLIS_MX8; - *(par + sidx + ii) = ar2_; - *(pai + sidx + ii) = ai2_; - *(pas + sidx + ii) = ar2_ + ai2_; - - sidx = ((pidx+6)/2) * BLIS_MX8; - *(par + sidx + ii) = ar3_; - *(pai + sidx + ii) = ai3_; - *(pas + sidx + ii) = ar3_ + ai3_; - pidx += 8; - } - - for (; p < (k*2); p += 2) - { - double ar_ = *(pa + idx + p); - double ai_ = -(*(pa + idx + p + 1)); - gint_t sidx = (pidx/2) * BLIS_MX8; - *(par + sidx + ii) = ar_; - *(pai + sidx + ii) = ai_; - *(pas + sidx + ii) = ar_ + ai_; - pidx += 2; - } - } - } - } //mx==8 - else//mx==1 - { - if(transa == BLIS_NO_TRANSPOSE) - { - pa = pa + i; -#if KLP -#else - p = 0; -#endif - //A No transpose case: - for (; p < k; p += 1) - { - gint_t idx = p * lda; - for (gint_t ii = 0; ii < (mx*2) ; ii += 2) - { //real + imag : Rkernel needs 8 elements each. - double ar_ = *(pa + idx + ii); - double ai_ = *(pa + idx + ii + 1); - *par = ar_; - *pai = ai_; - *pas = ar_ + ai_; - par++; pai++; pas++; - } - } - } - else if(transa == BLIS_CONJ_NO_TRANSPOSE) - { - pa = pa + i; -#if KLP -#else - p = 0; -#endif - //A conjuate No transpose case: - for (; p < k; p += 1) - { - gint_t idx = p * lda; - for (gint_t ii = 0; ii < (mx*2) ; ii += 2) - { //real + imag : Rkernel needs 8 elements each. - double ar_ = *(pa + idx + ii); - double ai_ = -(*(pa + idx + ii + 1));// conjugate: negate imaginary component - *par = ar_; - *pai = ai_; - *pas = ar_ + ai_; - par++; pai++; pas++; - } - } - } - else if(transa == BLIS_TRANSPOSE) - { - gint_t idx = (i/2) * lda; - pa = pa + idx; -#if KLP -#else - p = 0; -#endif - //A Transpose case: - for (gint_t ii = 0; ii < mx ; ii++) - { - gint_t idx = ii * lda; - gint_t sidx; - gint_t pidx = 0; - for (p = poffset;p < (k*2); p += 2) - { - double ar0_ = *(pa + idx + p); - double ai0_ = *(pa + idx + p + 1); - - sidx = (pidx/2) * mx; - *(par + sidx + ii) = ar0_; - *(pai + sidx + ii) = ai0_; - *(pas + sidx + ii) = ar0_ + ai0_; - pidx += 2; - - } - } - } - else if(transa == BLIS_CONJ_TRANSPOSE) - { - gint_t idx = (i/2) * lda; - pa = pa + idx; -#if KLP -#else - p = 0; -#endif - //A Transpose case: - for (gint_t ii = 0; ii < mx ; ii++) - { - gint_t idx = ii * lda; - gint_t sidx; - gint_t pidx = 0; - for (p = poffset;p < (k*2); p += 2) - { - double ar0_ = *(pa + idx + p); - double ai0_ = -(*(pa + idx + p + 1)); - - sidx = (pidx/2) * mx; - *(par + sidx + ii) = ar0_; - *(pai + sidx + ii) = ai0_; - *(pas + sidx + ii) = ar0_ + ai0_; - pidx += 2; - - } - } - } - }//mx==1 -} - diff --git a/kernels/zen/3/bli_gemm_sqp_kernels.h b/kernels/zen/3/bli_gemm_sqp_kernels.h deleted file mode 100644 index 588981fad0..0000000000 --- a/kernels/zen/3/bli_gemm_sqp_kernels.h +++ /dev/null @@ -1,65 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. 
- - Copyright (C) 2021, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ -/* square packed (sqp) kernels */ -#define KLP 1// k loop partition. - -/* sqp dgemm core kernels, targetted mainly for square sizes by default. - sqp framework allows tunning for other shapes.*/ -inc_t bli_sqp_dgemm_kernel_8mx6n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); -inc_t bli_sqp_dgemm_kernel_8mx5n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); -inc_t bli_sqp_dgemm_kernel_8mx4n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); -inc_t bli_sqp_dgemm_kernel_8mx3n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); -inc_t bli_sqp_dgemm_kernel_8mx2n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); -inc_t bli_sqp_dgemm_kernel_8mx1n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); -inc_t bli_sqp_dgemm_kernel_1mx1n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); -inc_t bli_sqp_dgemm_kernel_mxn(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, gint_t mx); - -//add and sub kernels -void bli_add_m(gint_t m,gint_t n,double* w,double* c); -void bli_sub_m(gint_t m, gint_t n, double* w, double* c); - -//packing kernels -//Pack A with alpha multiplication -void bli_sqp_prepackA(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha, gint_t mx); - -void bli_prepackA_8(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha); -void bli_prepackA_4(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha); -void bli_prepackA_G4(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha, gint_t mx); -void bli_prepackA_L4(double* pa, 
double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha, gint_t mx); -void bli_prepackA_1(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha); - -/* Pack real and imaginary parts in separate buffers and also multipy with multiplication factor */ -void bli_3m_sqp_packC_real_imag(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double mul, gint_t mx); -void bli_3m_sqp_packB_real_imag_sum(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double* pbs, double mul, gint_t mx); -void bli_3m_sqp_packA_real_imag_sum(double *pa, gint_t i, guint_t k, guint_t lda, double *par, double *pai, double *pas, trans_t transa, gint_t mx, gint_t p); \ No newline at end of file diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index bb8a2e9cc5..f5f7f37c6f 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -668,9 +668,11 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm15 = _mm256_setzero_pd(); /*GEMM block used in trsm small right cases*/ +/* B = 8x6, A = 6x6 */ #define BLIS_DTRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) \ {\ + _mm_prefetch((char*)( a01 + 8), _MM_HINT_T0); \ /*load 8x1 block of B10*/ \ ymm0 = _mm256_loadu_pd((double const *)b10); \ ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); \ @@ -709,7 +711,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref #define BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ - /*load 8x1 block of B10*/\ + /*load 4x1 block of B10*/\ ymm0 = _mm256_loadu_pd((double const *)b10); /*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ \ /*broadcast 1st row of A01*/\ @@ -735,6 +737,96 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b;\ } +#define BLIS_DTRSM_SMALL_GEMM_6nx3m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 3x1 block of B10*/\ + xmm5 = _mm_loadu_pd((double const*)(b10));\ + ymm0 = _mm256_broadcast_sd((double const *)(b10 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_6nx2m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 2x1 block of B10*/\ + xmm5 = _mm_loadu_pd((double const*)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 
= _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_6nx1m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 1x1 block of B10*/\ + ymm0 = _mm256_broadcast_sd((double const *)b10);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + #define BLIS_DTRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ @@ -763,6 +855,101 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b;\ } +#define BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 4x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_4nx3m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 3x1 block of B10*/\ + xmm5 = _mm_loadu_pd((double const*)(b10));\ + ymm0 = _mm256_broadcast_sd((double const *)(b10 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = 
_mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_4nx2m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 2x1 block of B10*/\ + xmm5 = _mm_loadu_pd((double const*)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_4nx1m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 1x1 block of B10*/\ + ymm0 = _mm256_broadcast_sd((double const *)b10);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + #define BLIS_DTRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter)\ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ @@ -787,47 +974,54 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b;\ } -#define BLIS_DTRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter)\ +#define BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ - /*load 8x1 block of B10*/\ + /*load 4x1 block of B10*/\ ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4));/*B10[4][0] B10[5][0] B10[6][0] B10[7][0]*/\ \ /*broadcast 1st row of A01*/\ ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);\ \ ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ \ a01 += 1; /*move to next row*/\ b10 += cs_b;\ } -#define BLIS_DTRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter)\ +#define BLIS_DTRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ - /*load 8x1 block of B10*/\ - ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4));/*B10[4][0] B10[5][0] B10[6][0] B10[7][0]*/\ + /*load 3x1 block of B10*/\ + xmm5 = _mm_loadu_pd((double const*)(b10));\ + ymm0 = _mm256_broadcast_sd((double const *)(b10 + 2));\ + ymm0 = 
_mm256_insertf128_pd(ymm0, xmm5, 0);\ \ /*broadcast 1st row of A01*/\ ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ \ a01 += 1; /*move to next row*/\ b10 += cs_b;\ } -#define BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) \ +#define BLIS_DTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ - /*load 8x1 block of B10*/\ - ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ + /*load 2x1 block of B10*/\ + xmm5 = _mm_loadu_pd((double const*)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ \ /*broadcast 1st row of A01*/\ ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ @@ -838,19 +1032,16 @@ BLIS_INLINE err_t dtrsm_XAltB_ref \ ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ -\ - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ \ a01 += 1; /*move to next row*/\ b10 += cs_b;\ } -#define BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) \ +#define BLIS_DTRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ - /*load 8x1 block of B10*/\ - ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ + /*load 1x1 block of B10*/\ + ymm0 = _mm256_broadcast_sd((double const *)b10);\ \ /*broadcast 1st row of A01*/\ ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ @@ -866,28 +1057,163 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b;\ } -#define BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) \ +#define BLIS_DTRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter)\ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ /*load 8x1 block of B10*/\ ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 4));/*B10[4][0] B10[5][0] B10[6][0] B10[7][0]*/\ \ /*broadcast 1st row of A01*/\ ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);\ \ ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);\ \ a01 += 1; /*move to next row*/\ b10 += cs_b;\ } -#define BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) \ +#define BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 4x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + 
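/* ------------------------------------------------------------------------
   Illustrative note (not part of the patch): the BLIS_DTRSM_SMALL_GEMM_<n>nx<m>m
   macros added above and below all follow one pattern -- for each of the k_iter
   columns of the B10 panel, load an m-element column fragment (partial loads via
   a 128-bit load plus a broadcast for odd remainders), then broadcast one A01
   element per accumulator and FMA it in.  The sketch below restates that pattern
   as a plain C function under assumed names (dtrsm_small_gemm_2n3m_sketch is
   hypothetical, not a BLIS symbol); it is a minimal sketch for readers and
   assumes an AVX2+FMA target.
   ------------------------------------------------------------------------ */
#include <immintrin.h>

/* Accumulates acc[j][0:2] += sum_k B10[0:2][k] * A01[k][j] for j = 0,1,
   mirroring BLIS_DTRSM_SMALL_GEMM_2nx3m (m = 3 rows of B10, n = 2 columns
   of A01; A01 packed with column stride p_lda, B10 with column stride cs_b). */
static void dtrsm_small_gemm_2n3m_sketch( const double *a01, const double *b10,
                                          long cs_b, long p_lda, long k_iter,
                                          double *acc0, double *acc1 )
{
    __m256d ymm3 = _mm256_setzero_pd();   /* accumulator for column 0 of A01 */
    __m256d ymm5 = _mm256_setzero_pd();   /* accumulator for column 1 of A01 */

    for ( long k = 0; k < k_iter; k++ )
    {
        /* Load a 3x1 fragment of B10: two doubles with a 128-bit load,
           the third by broadcast, merged into the low/high lanes of a ymm. */
        __m128d xmm5 = _mm_loadu_pd( b10 );
        __m256d ymm0 = _mm256_broadcast_sd( b10 + 2 );
        ymm0 = _mm256_insertf128_pd( ymm0, xmm5, 0 );

        /* Rank-1 update: broadcast one A01 element per accumulator and FMA. */
        __m256d ymm2 = _mm256_broadcast_sd( a01 + p_lda * 0 );  /* A01[k][0] */
        ymm3 = _mm256_fmadd_pd( ymm2, ymm0, ymm3 );

        ymm2 = _mm256_broadcast_sd( a01 + p_lda * 1 );          /* A01[k][1] */
        ymm5 = _mm256_fmadd_pd( ymm2, ymm0, ymm5 );

        a01 += 1;       /* next row of the packed A01 panel */
        b10 += cs_b;    /* next column of B10               */
    }

    /* Store only the 3 meaningful lanes of each accumulator, the same way the
       patch stores remainders: a 128-bit store plus a scalar store of lane 2. */
    _mm_storeu_pd( acc0,     _mm256_castpd256_pd128( ymm3 ) );
    _mm_storel_pd( acc0 + 2, _mm256_extractf128_pd( ymm3, 1 ) );
    _mm_storeu_pd( acc1,     _mm256_castpd256_pd128( ymm5 ) );
    _mm_storel_pd( acc1 + 2, _mm256_extractf128_pd( ymm5, 1 ) );
}

/* In the actual macros the accumulators stay live in ymm registers for the
   following TRSM solve instead of being stored, and alpha is folded in later
   with _mm256_fmsub_pd (alpha*B11 - accumulator) by the BLIS_PRE_DTRSM_SMALL_*
   macros; the store pattern above only shows how the m-remainder lanes are
   written back without touching memory past the valid rows. */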
+#define BLIS_DTRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 3x1 block of B10*/\ + xmm5 = _mm_loadu_pd((double const*)(b10));\ + ymm0 = _mm256_broadcast_sd((double const *)(b10 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 2x1 block of B10*/\ + xmm5 = _mm_loadu_pd((double const*)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 1x1 block of B10*/\ + ymm0 = _mm256_broadcast_sd((double const *)b10);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter)\ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ /*load 8x1 block of B10*/\ ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 4));/*B10[4][0] B10[5][0] B10[6][0] B10[7][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 4x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 3x1 block of B10*/\ + xmm5 = _mm_loadu_pd((double const*)(b10));\ + ymm0 = _mm256_broadcast_sd((double const *)(b10+ 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 
0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 2x1 block of B10*/\ + xmm5 = _mm_loadu_pd((double const*)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 1x1 block of B10*/\ + ymm0 = _mm256_broadcast_sd((double const *)b10);\ \ /*broadcast 1st row of A01*/\ ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ @@ -1262,8 +1588,12 @@ BLIS_INLINE err_t dtrsm_XAltB_ref #define BLIS_PRE_DTRSM_SMALL_3M_3N(AlphaVal,b11,cs_b)\ ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0));\ + ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ + ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ ymm2 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 2 + 2));\ ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0);\ @@ -1272,20 +1602,22 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);\ \ - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);\ - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);\ - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08);\ -\ - _mm256_storeu_pd((double *)(b11), ymm0); /*store(B11[0-3][0])*/\ - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); /*store(B11[0-3][1])*/\ - xmm5 = _mm256_extractf128_pd(ymm2, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm8);\ + _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5);\ + _mm_storel_pd((b11 + cs_b * 0 + 2), _mm256_extractf128_pd(ymm8, 1));\ + xmm5 = _mm256_castpd256_pd128(ymm9);\ + _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5);\ + _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm9, 1));\ + xmm5 = _mm256_castpd256_pd128(ymm10);\ _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5);\ - _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm2, 1)); + _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm10, 1)); #define BLIS_PRE_DTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b)\ ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0));\ + ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 2));\ ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ @@ -1293,13 +1625,12 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ \ - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);\ - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);\ -\ - _mm256_storeu_pd((double *)(b11), ymm0); /*store(B11[0-3][0])*/\ - xmm5 = _mm256_extractf128_pd(ymm1, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm8);\ + 
_mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5);\ + _mm_storel_pd((b11 + cs_b * 0 + 2), _mm256_extractf128_pd(ymm8, 1));\ + xmm5 = _mm256_castpd256_pd128(ymm9);\ _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5);\ - _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm1, 1)); + _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm9, 1)); #define BLIS_PRE_DTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b)\ ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ @@ -1308,18 +1639,19 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2));\ ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);\ \ - xmm5 = _mm256_extractf128_pd(ymm0, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm8);\ _mm_storeu_pd((double *)(b11), xmm5);\ - _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm0, 1)); + _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm8, 1)); #define BLIS_PRE_DTRSM_SMALL_2M_3N(AlphaVal,b11,cs_b)\ ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0);\ \ @@ -1327,30 +1659,27 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);\ \ - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C);\ - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C);\ - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C);\ -\ - _mm256_storeu_pd((double *)(b11), ymm0); /*store(B11[0-3][0])*/\ - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); /*store(B11[0-3][1])*/\ - xmm5 = _mm256_extractf128_pd(ymm2, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm8);\ + _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5);\ + xmm5 = _mm256_castpd256_pd128(ymm9);\ + _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5);\ + xmm5 = _mm256_castpd256_pd128(ymm10);\ _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); #define BLIS_PRE_DTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b)\ ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ \ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ \ - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C);\ - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C);\ -\ - _mm256_storeu_pd((double *)(b11), ymm0); /*store(B11[0-3][0])*/\ - xmm5 = _mm256_extractf128_pd(ymm1, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm8);\ + _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5);\ + xmm5 = _mm256_castpd256_pd128(ymm9);\ _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); #define BLIS_PRE_DTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b)\ @@ -1360,9 +1689,8 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ \ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C);\ \ - xmm5 = _mm256_extractf128_pd(ymm0, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm8);\ _mm_storeu_pd((double 
*)(b11 + cs_b * 0), xmm5); #define BLIS_PRE_DTRSM_SMALL_1M_3N(AlphaVal,b11,cs_b)\ @@ -1376,13 +1704,9 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);\ \ - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);\ - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);\ - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E);\ -\ - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0));\ - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0));\ - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); + _mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8,0));\ + _mm_storel_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm9,0));\ + _mm_storel_pd((double *)(b11 + cs_b * 2), _mm256_extractf128_pd(ymm10,0)); #define BLIS_PRE_DTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b)\ ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ @@ -1393,11 +1717,8 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ \ - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);\ - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);\ -\ - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0));\ - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); + _mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8,0));\ + _mm_storel_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm9,0)); #define BLIS_PRE_DTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b)\ ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ @@ -1405,18 +1726,20 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0));\ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ \ - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);\ -\ - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); + _mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8,0)); /* pre & post TRSM for Right remainder cases*/ #define BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ \ - ymm0 = _mm256_loadu_pd((double const *)b11);\ + xmm5 = _mm_loadu_pd((double const*)(b11));\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b));\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ \ xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ @@ -1425,28 +1748,26 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); #define BLIS_POST_DTRSM_SMALL_3N_3M(b11,cs_b)\ - ymm0 = _mm256_loadu_pd((double const *)b11);\ - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07);\ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07);\ - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2));\ - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07);\ \ - _mm256_storeu_pd((double *)b11, ymm3);\ - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5);\ - xmm5 = _mm256_extractf128_pd(ymm7, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm3);\ + _mm_storeu_pd((double *)(b11),xmm5);\ + _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm3, 1));\ + xmm5 = 
_mm256_castpd256_pd128(ymm5);\ + _mm_storeu_pd((double *)(b11 + cs_b),xmm5);\ + _mm_storel_pd((b11 + cs_b + 2), _mm256_extractf128_pd(ymm5, 1));\ + xmm5 = _mm256_castpd256_pd128(ymm7);\ _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5);\ _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm7, 1)); #define BLIS_PRE_DTRSM_SMALL_3N_2M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ \ - ymm0 = _mm256_loadu_pd((double const *)b11);\ + xmm5 = _mm_loadu_pd((double const*)(b11));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ \ xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ @@ -1454,17 +1775,12 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); #define BLIS_POST_DTRSM_SMALL_3N_2M(b11,cs_b)\ - ymm0 = _mm256_loadu_pd((double const *)b11);\ - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03);\ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03);\ - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03);\ \ - _mm256_storeu_pd((double *)b11, ymm3);\ - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5);\ - xmm5 = _mm256_extractf128_pd(ymm7, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm3);\ + _mm_storeu_pd((double *)(b11),xmm5);\ + xmm5 = _mm256_castpd256_pd128(ymm5);\ + _mm_storeu_pd((double *)(b11 + cs_b),xmm5);\ + xmm5 = _mm256_castpd256_pd128(ymm7);\ _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); #define BLIS_PRE_DTRSM_SMALL_3N_1M(AlphaVal,b11,cs_b)\ @@ -1480,21 +1796,17 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); #define BLIS_POST_DTRSM_SMALL_3N_1M(b11,cs_b)\ - ymm0 = _mm256_broadcast_sd((double const *)b11);\ - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01);\ - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b));\ - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01);\ - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2));\ - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01);\ \ - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0));\ - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0));\ - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); + _mm_storel_pd((b11 + cs_b * 0), _mm256_castpd256_pd128(ymm3));\ + _mm_storel_pd((b11 + cs_b * 1), _mm256_castpd256_pd128(ymm5));\ + _mm_storel_pd((b11 + cs_b * 2), _mm256_castpd256_pd128(ymm7)); #define BLIS_PRE_DTRSM_SMALL_2N_3M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ \ - ymm0 = _mm256_loadu_pd((double const *)b11);\ + xmm5 = _mm_loadu_pd((double const*)(b11));\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ \ xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ @@ -1503,22 +1815,19 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); #define BLIS_POST_DTRSM_SMALL_2N_3M(b11,cs_b)\ - ymm0 = _mm256_loadu_pd((double const *)b11);\ - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07);\ - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2));\ - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ - ymm5 = 
_mm256_blend_pd(ymm0, ymm5, 0x07);\ \ - _mm256_storeu_pd((double *)b11, ymm3);\ - xmm5 = _mm256_extractf128_pd(ymm5, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm3);\ + _mm_storeu_pd((double *)(b11), xmm5);\ + _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm3, 1));\ + xmm5 = _mm256_castpd256_pd128(ymm5);\ _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5);\ _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm5, 1)); #define BLIS_PRE_DTRSM_SMALL_2N_2M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ \ - ymm0 = _mm256_loadu_pd((double const *)b11);\ + xmm5 = _mm_loadu_pd((double const*)(b11));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ \ xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ @@ -1526,14 +1835,10 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); #define BLIS_POST_DTRSM_SMALL_2N_2M(b11,cs_b)\ - ymm0 = _mm256_loadu_pd((double const *)b11);\ - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03);\ - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03);\ \ - _mm256_storeu_pd((double *)b11, ymm3);\ - xmm5 = _mm256_extractf128_pd(ymm5, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm3);\ + _mm_storeu_pd((double *)(b11), xmm5);\ + xmm5 = _mm256_castpd256_pd128(ymm5);\ _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5); #define BLIS_PRE_DTRSM_SMALL_2N_1M(AlphaVal,b11,cs_b)\ @@ -1546,24 +1851,20 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); #define BLIS_POST_DTRSM_SMALL_2N_1M(b11,cs_b)\ - ymm0 = _mm256_broadcast_sd((double const *)b11);\ - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01);\ - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b));\ - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01);\ \ - _mm_storel_pd(b11 , _mm256_extractf128_pd(ymm3, 0));\ - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); + _mm_storel_pd(b11 , _mm256_castpd256_pd128(ymm3));\ + _mm_storel_pd((b11 + cs_b * 1), _mm256_castpd256_pd128(ymm5)); #define BLIS_PRE_DTRSM_SMALL_1N_3M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ \ xmm5 = _mm_loadu_pd((double const*)(b11));\ ymm0 = _mm256_broadcast_sd((double const *)(b11+ 2));\ - ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); #define BLIS_POST_DTRSM_SMALL_1N_3M(b11,cs_b)\ - xmm5 = _mm256_extractf128_pd(ymm3, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm3);\ _mm_storeu_pd((double *)(b11), xmm5);\ _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm3, 1)); @@ -1571,26 +1872,23 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ \ xmm5 = _mm_loadu_pd((double const*)(b11));\ - ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); #define BLIS_POST_DTRSM_SMALL_1N_2M(b11,cs_b)\ - ymm0 = _mm256_loadu_pd((double const *)b11);\ - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x03);\ \ - xmm5 = _mm256_extractf128_pd(ymm3, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm3);\ _mm_storeu_pd((double *)(b11), xmm5); #define BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ \ - ymm6 = 
_mm256_broadcast_sd((double const *)b11);\ - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); + ymm0 = _mm256_broadcast_sd((double const *)b11);\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); #define BLIS_POST_DTRSM_SMALL_1N_1M(b11,cs_b)\ - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x01);\ \ - _mm_storel_pd(b11, _mm256_extractf128_pd(ymm3, 0)); + _mm_storel_pd(b11, _mm256_castpd256_pd128(ymm3)); /* multiply with Alpha pre TRSM for 6*8 kernel*/ #define BLIS_PRE_DTRSM_SMALL_6x8(AlphaVal,b11,cs_b)\ @@ -1680,6 +1978,99 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5));\ ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); +#define BLIS_PRE_DTRSM_SMALL_6x3(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2));\ + xmm5 = _mm_loadu_pd((double const *)(b11));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2 + cs_b));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2 + cs_b*2));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2 + cs_b*3));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2 + cs_b*4));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2 + cs_b*5));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + +#define BLIS_PRE_DTRSM_SMALL_6x2(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + +#define BLIS_PRE_DTRSM_SMALL_6x1(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11));\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ +\ + ymm0 = 
_mm256_broadcast_sd ((double const *)(b11 + cs_b*2));\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + cs_b*3));\ + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + cs_b*4));\ + ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + cs_b*5));\ + ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION #define STRSM_SMALL_DIV_OR_SCALE _mm256_div_ps #endif @@ -1772,6 +2163,223 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b;\ } +#define BLIS_STRSM_SMALL_GEMM_6nx7m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 7x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_broadcast_ss((float const*)(b10 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_6nx6m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 6x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_6nx5m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 5x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)(b10 + 4));\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = 
_mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 4x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_6nx3m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 3x1 block of B10*/\ + __m128 xmm6 = _mm_broadcast_ss((float const *)(b10+ 2));\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b10)); \ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_6nx2m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 2x1 block of B10*/\ + xmm5 = _mm_setzero_ps();\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); 
/*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_6nx1m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 1x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)b10);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + #define BLIS_STRSM_SMALL_GEMM_4nx16m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ @@ -1883,11 +2491,15 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b;\ } -#define BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) \ +#define BLIS_STRSM_SMALL_GEMM_4nx7m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ - /*load 8x1 block of B10*/\ - ymm0 = _mm256_loadu_ps((float const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ + /*load 7x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_broadcast_ss((float const*)(b10 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ \ /*broadcast 1st row of A01*/\ ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ @@ -1898,16 +2510,22 @@ BLIS_INLINE err_t dtrsm_XAltB_ref \ ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ \ a01 += 1; /*move to next row*/\ b10 += cs_b;\ } -#define BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) \ +#define BLIS_STRSM_SMALL_GEMM_4nx6m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ - /*load 8x1 block of B10*/\ - ymm0 = _mm256_loadu_ps((float const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ + /*load 6x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ \ /*broadcast 1st row of A01*/\ ymm2 = 
_mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ @@ -1915,222 +2533,786 @@ BLIS_INLINE err_t dtrsm_XAltB_ref \ ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ \ a01 += 1; /*move to next row*/\ b10 += cs_b;\ } -#define BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) \ +#define BLIS_STRSM_SMALL_GEMM_4nx5m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ - /*load 8x1 block of B10*/\ - ymm0 = _mm256_loadu_ps((float const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ + /*load 5x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)(b10 + 4));\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ \ /*broadcast 1st row of A01*/\ ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ \ a01 += 1; /*move to next row*/\ b10 += cs_b;\ } -/*GEMM block used in strsm small left cases*/ -#define BLIS_STRSM_SMALL_GEMM_16mx6n(a10,b01,cs_b,p_lda,k_iter) \ - float *b01_prefetch = b01 + 8; \ - for(k = 0; k< k_iter; k++) \ - { \ - ymm0 = _mm256_loadu_ps((float const *)(a10)); \ - ymm1 = _mm256_loadu_ps((float const *)(a10 + 8)); \ - _mm_prefetch((char*)( a10 + 64), _MM_HINT_T0); \ - /*Calculate the next micro pannel address to prefetch*/ \ - if(k & 0x7) b01_prefetch += cs_b; \ - else b01_prefetch = b01+ 8; \ - ymm2 = _mm256_broadcast_ss((float const *)(b01)); \ - ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8); \ - ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); \ - \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1)); \ - ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9); \ - ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); \ - \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2)); \ - ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10); \ - ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); \ - \ - /*Prefetch the next 6x8 micro panelof B */ \ - _mm_prefetch((char*)( b01_prefetch), _MM_HINT_T0); \ - \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 3)); \ - ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11); \ - ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15); \ - \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 4)); \ - ymm4 = _mm256_fmadd_ps(ymm2, ymm0, ymm4); \ - ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); \ - \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 5)); \ - ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5); \ - ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7); \ - \ - b01 += 1; \ - a10 += p_lda; \ - } - -#define BLIS_STRSM_SMALL_GEMM_16mx4n(a10,b01,cs_b,p_lda,k_iter) \ - for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ +#define BLIS_STRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ - ymm0 = _mm256_loadu_ps((float const *)(a10));\ - ymm1 = 
_mm256_loadu_ps((float const *)(a10 + 8));\ + /*load 4x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01));\ - ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ - ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ - ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ - ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ - ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ - ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 3));\ - ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ - ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ \ - b01 += 1; /*move to next row of B*/\ - a10 += p_lda; /*pointer math to calculate next block of A for GEMM*/\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ } -#define BLIS_STRSM_SMALL_GEMM_16mx3n(a10,b01,cs_b,p_lda,k_iter) \ - for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ +#define BLIS_STRSM_SMALL_GEMM_4nx3m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ - ymm0 = _mm256_loadu_ps((float const *)(a10));\ - ymm1 = _mm256_loadu_ps((float const *)(a10 + 8));\ + /*load 3x1 block of B10*/\ + __m128 xmm6 = _mm_broadcast_ss((float const *)(b10+ 2));\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b10)); \ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ - ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ - ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ - ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ - ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ - ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ - ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ \ - b01 += 1; /*move to next row of B*/\ - a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ } -#define BLIS_STRSM_SMALL_GEMM_16mx2n(a10,b01,cs_b,p_lda,k_iter) \ - for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ +#define BLIS_STRSM_SMALL_GEMM_4nx2m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ - ymm0 = _mm256_loadu_ps((float const *)(a10));\ - ymm1 = 
_mm256_loadu_ps((float const *)(a10 + 8));\ + /*load 2x1 block of B10*/\ + xmm5 = _mm_setzero_ps();\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ - ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ - ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ - ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);\ \ - b01 += 1; /*move to next row of B*/\ - a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ } -#define BLIS_STRSM_SMALL_GEMM_16mx1n(a10,b01,cs_b,p_lda,k_iter) \ - for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ +#define BLIS_STRSM_SMALL_GEMM_4nx1m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ - ymm0 = _mm256_loadu_ps((float const *)(a10));\ - ymm1 = _mm256_loadu_ps((float const *)(a10 + 8));\ + /*load 1x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)b10);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ - ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ - ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);\ - b01 += 1; /*move to next row of B*/\ - a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ - } - -#define BLIS_STRSM_SMALL_GEMM_8mx6n(a10,b01,cs_b,p_lda,k_iter) \ - for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ - {\ - ymm0 = _mm256_loadu_ps((float const *)(a10));\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ - ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ - ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ - ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 3));\ - ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_ps((float const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 4));\ - ymm4 = _mm256_fmadd_ps(ymm2, ymm0, ymm4);\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float 
const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 5));\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ \ - b01 += 1; /*move to next row of B*/\ - a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ } -#define BLIS_STRSM_SMALL_GEMM_8mx4n(a10,b01,cs_b,p_lda,k_iter) \ - for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ +#define BLIS_STRSM_SMALL_GEMM_3nx7m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ - ymm0 = _mm256_loadu_ps((float const *)(a10));\ -\ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ - ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ + /*load 7x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_broadcast_ss((float const*)(b10 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ - ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ - ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 3));\ - ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ \ - b01 += 1; /*move to next row of B*/\ - a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ } -#define BLIS_STRSM_SMALL_GEMM_8mx3n(a10,b01,cs_b,p_lda,k_iter) \ - for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ +#define BLIS_STRSM_SMALL_GEMM_3nx6m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ - ymm0 = _mm256_loadu_ps((float const *)(a10));\ + /*load 6x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ - ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ - ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ - ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ \ - b01 += 1; /*move to next row of B*/\ - a10 += p_lda;/*pointer math to calculate next block of A 
for GEMM*/\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ } -#define BLIS_STRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) \ - for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ +#define BLIS_STRSM_SMALL_GEMM_3nx5m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ - ymm0 = _mm256_loadu_ps((float const *)(a10));\ + /*load 5x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)(b10 + 4));\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ \ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ - ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ -\ - ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ - ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 4x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 3x1 block of B10*/\ + __m128 xmm6 = _mm_broadcast_ss((float const *)(b10+ 2));\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b10)); \ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 2x1 block of B10*/\ + xmm5 = _mm_setzero_ps();\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define 
BLIS_STRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 1x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)b10);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_ps((float const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_2nx7m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 7x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_broadcast_ss((float const*)(b10 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_2nx6m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 6x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_2nx5m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 5x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)(b10 + 4));\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + 
/*load 4x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 3x1 block of B10*/\ + __m128 xmm6 = _mm_broadcast_ss((float const *)(b10+ 2));\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b10)); \ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 2x1 block of B10*/\ + xmm5 = _mm_setzero_ps();\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 1x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)b10);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_ps((float const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_1nx7m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 7x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_broadcast_ss((float const*)(b10 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_1nx6m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of 
GEMM operations*/\ + {\ + /*load 6x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_1nx5m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 5x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)(b10 + 4));\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 4x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 3x1 block of B10*/\ + __m128 xmm6 = _mm_broadcast_ss((float const *)(b10+ 2));\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b10)); \ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 2x1 block of B10*/\ + xmm5 = _mm_setzero_ps();\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 1x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)b10);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +/*GEMM block used in strsm small left cases*/ +#define BLIS_STRSM_SMALL_GEMM_16mx6n(a10,b01,cs_b,p_lda,k_iter) \ + float *b01_prefetch = b01 + 8; \ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_ps((float const *)(a10)); \ + ymm1 = _mm256_loadu_ps((float const *)(a10 + 8)); \ + _mm_prefetch((char*)( a10 + 64), _MM_HINT_T0); \ + /*Calculate the next micro pannel address to prefetch*/ \ + if(k & 0x7) b01_prefetch += cs_b; \ + else b01_prefetch = b01+ 8; \ + ymm2 = _mm256_broadcast_ss((float const *)(b01)); \ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8); \ + ymm12 = 
_mm256_fmadd_ps(ymm2, ymm1, ymm12); \ + \ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1)); \ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9); \ + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); \ + \ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2)); \ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10); \ + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); \ + \ + /*Prefetch the next 6x8 micro panelof B */ \ + _mm_prefetch((char*)( b01_prefetch), _MM_HINT_T0); \ + \ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 3)); \ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11); \ + ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15); \ + \ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 4)); \ + ymm4 = _mm256_fmadd_ps(ymm2, ymm0, ymm4); \ + ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); \ + \ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 5)); \ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5); \ + ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7); \ + \ + b01 += 1; \ + a10 += p_lda; \ + } + +#define BLIS_STRSM_SMALL_GEMM_16mx4n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_ps((float const *)(a10));\ + ymm1 = _mm256_loadu_ps((float const *)(a10 + 8));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 3));\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ + ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda; /*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_16mx3n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_ps((float const *)(a10));\ + ymm1 = _mm256_loadu_ps((float const *)(a10 + 8));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_16mx2n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_ps((float const *)(a10));\ + ymm1 = _mm256_loadu_ps((float const *)(a10 + 8));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define 
BLIS_STRSM_SMALL_GEMM_16mx1n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_ps((float const *)(a10));\ + ymm1 = _mm256_loadu_ps((float const *)(a10 + 8));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_8mx6n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_ps((float const *)(a10));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 3));\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 4));\ + ymm4 = _mm256_fmadd_ps(ymm2, ymm0, ymm4);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 5));\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_8mx4n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_ps((float const *)(a10));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 3));\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_8mx3n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_ps((float const *)(a10));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_ps((float const *)(a10));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ \ b01 += 1; /*move to next row of B*/\ a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ @@ -3226,6 +4408,280 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5));\ ymm13 = 
_mm256_fmsub_ps(ymm0, ymm15, ymm13); +#define BLIS_PRE_STRSM_SMALL_6x7(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + cs_b + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + cs_b*2 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*2 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + cs_b*3 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*3 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + cs_b*4 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*4 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + cs_b*5 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*5 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x6(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*5));\ + ymm0 = 
_mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x5(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4));\ + xmm5 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b));\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*2));\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*3));\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*4));\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*5));\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x4(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x3(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2 + cs_b));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2 + cs_b*2));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2 + cs_b*3));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*3));\ 
+ ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2 + cs_b*4));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2 + cs_b*5));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x2(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x1(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_ss((float const *)b11);\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*2));\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*3));\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*4));\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*5));\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + /* Load b11 of size 6x8 and multiply with alpha Add the GEMM output and perform inregister transose of b11 @@ -3439,7 +4895,6 @@ BLIS_INLINE void bli_dtrsm_small_pack __m256d ymm8, ymm9, ymm10, ymm11; __m256d ymm12, ymm13; __m128d xmm0,xmm1,xmm2,xmm3; - double zero = 0.0; if(side=='L'||side=='l') { @@ -3595,12 +5050,10 @@ BLIS_INLINE void bli_dtrsm_small_pack ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_broadcast_sd((double const *)&zero); ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_broadcast_sd((double const *)&zero); _mm256_storeu_pd((double *)(pbuff + p_lda * 4), ymm6); _mm256_storeu_pd((double *)(pbuff + p_lda * 5), ymm7); @@ -3611,32 +5064,19 @@ BLIS_INLINE void bli_dtrsm_small_pack ymm11 = 
_mm256_loadu_pd((double const *)(inbuf + cs_a * 5 + 4)); ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(pbuff + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(pbuff + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(pbuff + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(pbuff + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(pbuff + 4), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(pbuff + 4 + p_lda), _mm256_castpd256_pd128(ymm0)); + _mm_storeu_pd((double *)(pbuff + 4 + p_lda*2), _mm256_extractf128_pd(ymm4,1)); + _mm_storeu_pd((double *)(pbuff + 4 + p_lda*3), _mm256_extractf128_pd(ymm0,1)); ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); - ymm5 = _mm256_broadcast_sd((double const *)&zero); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_broadcast_sd((double const *)&zero); ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)&zero); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_broadcast_sd((double const *)&zero); - _mm_storeu_pd((double *)(pbuff + p_lda * 4 + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(pbuff + p_lda * 5 + 4), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(pbuff + p_lda * 4 + 4), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(pbuff + p_lda * 5 + 4), _mm256_castpd256_pd128(ymm0)); inbuf += mr*cs_a; pbuff += mr; } @@ -3740,7 +5180,7 @@ BLIS_INLINE void dtrsm_small_pack_diag_element if(is_eight){ _mm256_store_pd((double *)(d11_pack + 4), ymm5); }else{ - _mm_storeu_pd((double *)(d11_pack + 4), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(d11_pack + 4), _mm256_castpd256_pd128(ymm5)); } } @@ -3832,14 +5272,16 @@ err_t bli_trsm_small case BLIS_FLOAT: case BLIS_SCOMPLEX: { - if(m > 1000 || n > 1000) { + bool nt = bli_thread_get_is_parallel(); + if((nt == 0) && (m > 1000 || n > 1000)) { return BLIS_NOT_YET_IMPLEMENTED; } break; } case BLIS_DCOMPLEX: { - if(m > 500 || n > 500) { + bool nt = bli_thread_get_is_parallel(); + if((nt == 0) && (m > 500 || n > 500)) { return BLIS_NOT_YET_IMPLEMENTED; } break; @@ -3920,6 +5362,11 @@ err_t bli_trsm_small_mt d_mr = 8,d_nr = 6; break; } + case BLIS_DCOMPLEX: + { + d_mr = 4,d_nr = 3; + break; + } default: { return BLIS_NOT_YET_IMPLEMENTED; @@ -3934,7 +5381,7 @@ err_t bli_trsm_small_mt // If dynamic-threading is enabled, calculate optimum number // of threads. // rntm will be updated with optimum number of threads. 
- if( bli_obj_is_double(b)) + if( bli_obj_is_double(b) ) { bli_nthreads_optimum(a, b, b, BLIS_TRSM, &rntm); } @@ -4291,7 +5738,7 @@ BLIS_INLINE err_t ztrsm_AuXB_ref /*get the dcomplex mul answer into register*/\ ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ ymm8 = _mm256_sub_pd(ymm15,ymm8);\ - xmm5 = _mm256_extractf128_pd(ymm8, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm8);\ /*store dcomplex elements*/\ _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5);\ } @@ -4329,9 +5776,9 @@ BLIS_INLINE err_t ztrsm_AuXB_ref ymm14 = _mm256_mul_pd(ymm1, ymm14);\ ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ ymm9 = _mm256_sub_pd(ymm15,ymm9);\ - xmm4 = _mm256_extractf128_pd(ymm8, 0);\ + xmm4 = _mm256_castpd256_pd128(ymm8);\ _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm4);\ - xmm5 = _mm256_extractf128_pd(ymm9, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm9);\ _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5);\ } @@ -4688,6 +6135,258 @@ BLIS_INLINE err_t ztrsm_AuXB_ref ymm10 = _mm256_addsub_pd(ymm10, ymm6);\ } +#define BLIS_ZTRSM_SMALL_GEMM_2mx2n(a10,b01,cs_b,p_lda,k_iter){\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm5);\ +} + +#define BLIS_ZTRSM_SMALL_GEMM_2mx1n(a10,b01,cs_b,p_lda,k_iter){\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< 
k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ +} + +#define BLIS_ZTRSM_SMALL_GEMM_1mx3n(a10,b01,cs_b,p_lda,k_iter){\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + xmm4 = _mm_loadu_pd((double const *)(a10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0); \ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1));\ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm1, ymm10);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + xmm4 = _mm_loadu_pd((double const *)(a10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0); \ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1));\ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm1, ymm10);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm5);\ + ymm10 = _mm256_addsub_pd(ymm10, ymm6);\ +} + +#define BLIS_ZTRSM_SMALL_GEMM_1mx2n(a10,b01,cs_b,p_lda,k_iter){\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + xmm4 = _mm_loadu_pd((double const *)(a10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0); \ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double 
const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + xmm4 = _mm_loadu_pd((double const *)(a10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0); \ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm5);\ +} + +#define BLIS_ZTRSM_SMALL_GEMM_1mx1n(a10,b01,cs_b,p_lda,k_iter){\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + xmm4 = _mm_loadu_pd((double const *)(a10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0); \ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + xmm4 = _mm_loadu_pd((double const *)(a10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0); \ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ +} + /** * Performs GEMM operation. * Four elements of column in ymm0, ymm1. 
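For reference, the BLIS_ZTRSM_SMALL_GEMM_2mx2n/2mx1n/1mx3n/1mx2n/1mx1n fringe macros added above all use the same split-accumulator scheme for dcomplex arithmetic: the A10 column is loaded as a packed [re, im, re, im] vector, the real and imaginary parts of each B01 element are broadcast separately, FMAs build one "times real" and one "times imaginary" accumulator, and a single permute + addsub at the end folds the two accumulators into the complex result (the conjugate path only negates the imaginary lanes of the loaded A column first via the (1,-1,1,-1) multiplier). The standalone sketch below is not part of the patch; it is an illustrative check of that identity for one B element against two A elements, assuming an AVX2/FMA-capable compiler (e.g. gcc -O2 -mavx2 -mfma).

/* Standalone check (not part of the patch) of the split-accumulator
 * dcomplex multiply used by the BLIS_ZTRSM_SMALL_GEMM_* macros:
 * acc_r += a_col * re(b), acc_i += a_col * im(b) for every k,
 * then one permute + addsub folds the two accumulators together. */
#include <immintrin.h>
#include <stdio.h>

int main(void)
{
    /* two dcomplex elements of an A10 column: (1+2i) and (3+4i) */
    double a_col[4]  = { 1.0, 2.0, 3.0, 4.0 };
    /* one dcomplex element of B01: (5+6i) */
    double b_elem[2] = { 5.0, 6.0 };

    __m256d ymm0 = _mm256_loadu_pd(a_col);            /* [re0 im0 re1 im1] */
    __m256d ymm1 = _mm256_broadcast_sd(&b_elem[0]);   /* broadcast re(b)   */
    __m256d ymm2 = _mm256_broadcast_sd(&b_elem[1]);   /* broadcast im(b)   */

    __m256d acc_r = _mm256_setzero_pd();              /* ymm8 in the 2mx1n macro */
    __m256d acc_i = _mm256_setzero_pd();              /* ymm4 in the 2mx1n macro */
    acc_r = _mm256_fmadd_pd(ymm0, ymm1, acc_r);       /* a_col * re(b) */
    acc_i = _mm256_fmadd_pd(ymm0, ymm2, acc_i);       /* a_col * im(b) */

    /* swap re/im inside each 128-bit lane, then addsub:
     * lane k becomes [re_a*re_b - im_a*im_b, im_a*re_b + re_a*im_b] */
    acc_i = _mm256_permute_pd(acc_i, 0x5);
    __m256d prod = _mm256_addsub_pd(acc_r, acc_i);

    double out[4];
    _mm256_storeu_pd(out, prod);
    /* expected: (1+2i)(5+6i) = -7+16i  and  (3+4i)(5+6i) = -9+38i */
    printf("%g%+gi  %g%+gi\n", out[0], out[1], out[2], out[3]);
    return 0;
}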
@@ -4854,7 +6553,7 @@ BLIS_INLINE err_t ztrsm_AuXB_ref ymm12 = _mm256_sub_pd(ymm15,ymm12);\ \ _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm8);\ - xmm5 = _mm256_extractf128_pd(ymm12, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm12);\ _mm_storeu_pd((double *)(b11 + cs_b * 0 + 2), xmm5);\ } @@ -4910,9 +6609,9 @@ BLIS_INLINE err_t ztrsm_AuXB_ref \ _mm256_storeu_pd((double *)(b11), ymm8);\ _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9);\ - xmm4 = _mm256_extractf128_pd(ymm12, 0);\ + xmm4 = _mm256_castpd256_pd128(ymm12);\ _mm_storeu_pd((double *)(b11 + cs_b * 0 + 2), xmm4);\ - xmm5 = _mm256_extractf128_pd(ymm13, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm13);\ _mm_storeu_pd((double *)(b11 + cs_b * 1 + 2), xmm5);\ } @@ -5117,6 +6816,136 @@ BLIS_INLINE err_t ztrsm_AuXB_ref ymm6 = _mm256_addsub_pd(ymm6, ymm11);\ } +#define BLIS_ZTRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + xmm5 = _mm_loadu_pd((double const *)(b10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm12 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + ymm12 = _mm256_mul_pd(ymm12, ymm18);\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm8 = _mm256_fmadd_pd(ymm0, ymm12, ymm8);\ + ymm9 = _mm256_fmadd_pd(ymm1, ymm12, ymm9);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm12 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + ymm12 = _mm256_mul_pd(ymm12, ymm18);\ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm12, ymm10);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm12, ymm11);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + xmm5 = _mm_loadu_pd((double const *)(b10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm12 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm8 = _mm256_fmadd_pd(ymm0, ymm12, ymm8);\ + ymm9 = _mm256_fmadd_pd(ymm1, ymm12, ymm9);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm12 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm12, ymm10);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm12, ymm11);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + ymm8 = _mm256_permute_pd(ymm8, 0x5);\ + ymm9 = _mm256_permute_pd(ymm9, 0x5);\ + ymm10 = _mm256_permute_pd(ymm10, 0x5);\ + ymm11 = _mm256_permute_pd(ymm11, 0x5);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm8);\ + ymm4 = _mm256_addsub_pd(ymm4, ymm9);\ + ymm5 = _mm256_addsub_pd(ymm5, ymm10);\ + ymm6 = _mm256_addsub_pd(ymm6, ymm11);\ +} + +#define BLIS_ZTRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter){\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + 
xmm5 = _mm_loadu_pd((double const *)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + ymm2 = _mm256_mul_pd(ymm2, ymm18);\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm1, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + ymm2 = _mm256_mul_pd(ymm2, ymm18);\ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm1, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_pd((double const *)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm1, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm1, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm4);\ + ymm5 = _mm256_addsub_pd(ymm5, ymm6);\ +} + /** * Performs GEMM operation * ymm0 holds 2 elements of a column. @@ -5249,59 +7078,62 @@ BLIS_INLINE err_t ztrsm_AuXB_ref * 3 elements of a columns get held by ymm0(2 element) * and xmm5 (1 element). 
*/ -#define BLIS_ZTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) {\ + #define BLIS_ZTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) {\ double *tptr = (double *)a01;\ if(conjtransa) {\ ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ - for(k = 0; k< k_iter; k++) \ + for(k = 0; k < k_iter; k++)\ {\ - ymm0 = _mm256_loadu_pd((double const *)(b10)); \ - /*ymm1 = _mm256_loadu_pd((double const *)(b10 + 2));*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);\ xmm5 = _mm_loadu_pd((double const *)(b10 + 2));\ - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ \ _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ - ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ - ymm5 = _mm256_mul_pd(ymm5, ymm18);\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0));\ + ymm7 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1));\ + ymm7 = _mm256_mul_pd(ymm7, ymm18);\ + /*dcomplex multiplication and substraction*/\ \ ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ - ymm4 = _mm256_fmadd_pd(ymm1, ymm5, ymm4);\ - ymm7 = _mm256_fmadd_pd(ymm1, ymm5, ymm7);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm7, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm7, ymm6);\ + /*dcomplex multiplication and substraction*/\ \ tptr += 2;\ b10 += cs_b;\ }\ }\ else {\ - for(k = 0; k< k_iter; k++) \ + for(k = 0; k < k_iter; k++)\ {\ - ymm0 = _mm256_loadu_pd((double const *)(b10)); \ - /*ymm1 = _mm256_loadu_pd((double const *)(b10 + 2));*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);\ xmm5 = _mm_loadu_pd((double const *)(b10 + 2));\ - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ \ _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ - ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0));\ + ymm7 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1));\ + /*dcomplex multiplication and substraction*/\ \ ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ - ymm4 = _mm256_fmadd_pd(ymm1, ymm5, ymm4);\ - ymm7 = _mm256_fmadd_pd(ymm1, ymm5, ymm7);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm7, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm7, ymm6);\ + /*ymm3 = _mm256_add_pd(ymm15, ymm3);*/\ + /*dcomplex multiplication and substraction*/\ \ tptr += 2;\ b10 += cs_b;\ }\ }\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ ymm6 = _mm256_permute_pd(ymm6, 0x5);\ - ymm7 = _mm256_permute_pd(ymm7, 0x5);\ - ymm3 = _mm256_addsub_pd(ymm3, ymm6);\ - ymm4 = _mm256_addsub_pd(ymm5, ymm7);\ +\ + ymm3 = _mm256_addsub_pd(ymm3, ymm5);\ + ymm4 = _mm256_addsub_pd(ymm4, ymm6);\ } - /** * Performs GEMM operation. * 1 elements of a column are kept in ymm0. 
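Note: every BLIS_ZTRSM_SMALL_GEMM_* macro in this patch builds the double-complex multiply out of the same idiom: broadcast the real part of the A element, broadcast the imaginary part (pre-negated via ymm18 = {-1,-1,-1,-1} when conjtransa is set), accumulate both products with FMAs, and finally combine with a within-lane swap (_mm256_permute_pd(x, 0x5)) followed by _mm256_addsub_pd. The standalone sketch below is illustrative only and not part of the patch (values and variable names are made up); it reproduces that arithmetic for the two dcomplex elements held in one ymm register and can be compiled with -mavx -mfma.

#include <immintrin.h>
#include <stdio.h>

int main(void)
{
    /* b = { 1+2i, 3+4i }, a = 5+6i (illustrative values) */
    double b[4] = { 1.0, 2.0, 3.0, 4.0 };
    double a[2] = { 5.0, 6.0 };

    __m256d ymm0 = _mm256_loadu_pd(b);            /* b0.re b0.im b1.re b1.im */
    __m256d ymm2 = _mm256_broadcast_sd(&a[0]);    /* a.re in every lane      */
    __m256d ymm7 = _mm256_broadcast_sd(&a[1]);    /* a.im in every lane      */

    __m256d acc_r = _mm256_setzero_pd();          /* accumulates b * a.re */
    __m256d acc_i = _mm256_setzero_pd();          /* accumulates b * a.im */

    acc_r = _mm256_fmadd_pd(ymm0, ymm2, acc_r);
    acc_i = _mm256_fmadd_pd(ymm0, ymm7, acc_i);

    /* Swap re/im inside each 128-bit lane, then addsub yields
       (re*re - im*im, im*re + re*im) per complex element. */
    acc_i = _mm256_permute_pd(acc_i, 0x5);
    __m256d res = _mm256_addsub_pd(acc_r, acc_i);

    double out[4];
    _mm256_storeu_pd(out, res);
    printf("%g%+gi  %g%+gi\n", out[0], out[1], out[2], out[3]); /* -7+16i  -9+38i */
    return 0;
}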
@@ -5345,7 +7177,7 @@ BLIS_INLINE err_t ztrsm_AuXB_ref }\ }\ ymm4 = _mm256_permute_pd(ymm4, 0x5);\ - ymm3 = _mm256_addsub_pd(ymm3, ymm4);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm4);\ } @@ -5390,7 +7222,7 @@ BLIS_INLINE err_t ztrsm_AuXB_ref }\ }\ ymm4 = _mm256_permute_pd(ymm4, 0x5);\ - ymm3 = _mm256_addsub_pd(ymm3, ymm4);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm4);\ } /** @@ -5488,6 +7320,184 @@ BLIS_INLINE err_t ztrsm_AuXB_ref ymm8 = _mm256_addsub_pd(ymm8, ymm15);\ } +#define BLIS_ZTRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + xmm5 = _mm_loadu_pd((double const *)(b10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0));\ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1));\ + ymm9 = _mm256_mul_pd(ymm9, ymm18);\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm9, ymm10);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm9, ymm11);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + ymm9 = _mm256_mul_pd(ymm9, ymm18);\ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6);\ + ymm12 = _mm256_fmadd_pd(ymm0, ymm9, ymm12);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm9, ymm13);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2)); \ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2 + 1)); \ + ymm9 = _mm256_mul_pd(ymm9, ymm18);\ + \ + ymm7 = _mm256_fmadd_pd(ymm0, ymm2, ymm7);\ + ymm8 = _mm256_fmadd_pd(ymm1, ymm2, ymm8);\ + ymm14 = _mm256_fmadd_pd(ymm0, ymm9, ymm14);\ + ymm15 = _mm256_fmadd_pd(ymm1, ymm9, ymm15);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + xmm5 = _mm_loadu_pd((double const *)(b10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0));\ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1));\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm9, ymm10);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm9, ymm11);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6);\ + ymm12 = _mm256_fmadd_pd(ymm0, ymm9, ymm12);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm9, ymm13);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2)); \ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2 + 1)); \ + \ + ymm7 = _mm256_fmadd_pd(ymm0, ymm2, ymm7);\ + ymm8 = _mm256_fmadd_pd(ymm1, ymm2, ymm8);\ + ymm14 = _mm256_fmadd_pd(ymm0, ymm9, ymm14);\ + ymm15 = _mm256_fmadd_pd(ymm1, ymm9, ymm15);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + ymm10 = _mm256_permute_pd(ymm10, 0x5);\ + ymm11 = _mm256_permute_pd(ymm11, 0x5);\ + ymm12 = _mm256_permute_pd(ymm12, 0x5);\ + ymm13 = _mm256_permute_pd(ymm13, 0x5);\ + ymm14 = 
_mm256_permute_pd(ymm14, 0x5);\ + ymm15 = _mm256_permute_pd(ymm15, 0x5);\ +\ + ymm3 = _mm256_addsub_pd(ymm3, ymm10);\ + ymm4 = _mm256_addsub_pd(ymm4, ymm11);\ + ymm5 = _mm256_addsub_pd(ymm5, ymm12);\ + ymm6 = _mm256_addsub_pd(ymm6, ymm13);\ + ymm7 = _mm256_addsub_pd(ymm7, ymm14);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm15);\ +} + +#define BLIS_ZTRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + {\ + xmm5 = _mm_loadu_pd((double const *)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + ymm6 = _mm256_mul_pd(ymm6, ymm18);\ + /*dcomplex multiplication and substraction*/\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm4, ymm3);\ + ymm8 = _mm256_fmadd_pd(ymm0, ymm6, ymm8);\ + \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + ymm6 = _mm256_mul_pd(ymm6, ymm18);\ + \ + /*dcomplex multiplication and substraction*/\ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm4, ymm5);\ + ymm9 = _mm256_fmadd_pd(ymm0, ymm6, ymm9);\ + \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2 + 1)); \ + ymm6 = _mm256_mul_pd(ymm6, ymm18);\ + \ + /*dcomplex multiplication and substraction*/\ + \ + ymm7 = _mm256_fmadd_pd(ymm0, ymm4, ymm7);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm6, ymm10);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + {\ + xmm5 = _mm_loadu_pd((double const *)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + /*dcomplex multiplication and substraction*/\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm4, ymm3);\ + ymm8 = _mm256_fmadd_pd(ymm0, ymm6, ymm8);\ + /*ymm3 = _mm256_add_pd(ymm15, ymm3);*/\ + \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + \ + /*dcomplex multiplication and substraction*/\ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm4, ymm5);\ + ymm9 = _mm256_fmadd_pd(ymm0, ymm6, ymm9);\ + /*ymm5 = _mm256_add_pd(ymm15, ymm5);*/\ + \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2 + 1)); \ + \ + /*dcomplex multiplication and substraction*/\ + \ + ymm7 = _mm256_fmadd_pd(ymm0, ymm4, ymm7);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm6, ymm10);\ + /*ymm7 = _mm256_add_pd(ymm15, ymm7);*/\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + ymm8 = _mm256_permute_pd(ymm8, 0x5);\ + ymm9 = _mm256_permute_pd(ymm9, 0x5);\ + ymm10 = _mm256_permute_pd(ymm10, 0x5);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm8);\ + ymm5 = _mm256_addsub_pd(ymm5, ymm9);\ + ymm7 = _mm256_addsub_pd(ymm7, ymm10);\ +} + /** * Multiplies Alpha with 4 element of 2 columns. * ymm0 and ymm1 holds 4 elements of a column. 
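Note: the BLIS_PRE_ZTRSM_SMALL_* macros in the next hunks scale B by alpha before the TRSM step using a permute/mul/hsub sequence rather than a scalar complex multiply: alpha is broadcast as a 128-bit pair, a swapped copy is signed with {1,-1,1,-1}, and _mm256_hsub_pd folds the two products into Re(alpha*b), Im(alpha*b), from which the GEMM accumulation is then subtracted. The sketch below is illustrative only and not part of the patch (values are made up); it checks the sequence for a single dcomplex and can be compiled with -mavx.

#include <immintrin.h>
#include <stdio.h>

int main(void)
{
    double alpha[2] = { 2.0, 1.0 };   /* 2+1i (illustrative)           */
    double b[2]     = { 3.0, 4.0 };   /* 3+4i                          */
    double acc[2]   = { 1.0, 1.0 };   /* prior GEMM accumulation, 1+1i */

    __m256d ymm16 = _mm256_broadcast_pd((__m128d const *)alpha); /* re im re im */
    __m256d ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);        /* sign mask   */
    __m256d ymm0  = _mm256_broadcast_pd((__m128d const *)b);
    __m256d ymm3  = _mm256_broadcast_pd((__m128d const *)acc);

    __m256d ymm14 = _mm256_permute_pd(ymm16, 0x5);   /* im re im re           */
    ymm14 = _mm256_mul_pd(ymm14, ymm18);             /* im -re im -re         */
    __m256d ymm17 = _mm256_mul_pd(ymm0, ymm16);      /* b.re*a.re, b.im*a.im  */
    ymm14 = _mm256_mul_pd(ymm0, ymm14);              /* b.re*a.im, -b.im*a.re */
    __m256d ymm15 = _mm256_hsub_pd(ymm17, ymm14);    /* Re(a*b), Im(a*b)      */
    ymm15 = _mm256_sub_pd(ymm15, ymm3);              /* alpha*b - acc         */

    double out[4];
    _mm256_storeu_pd(out, ymm15);
    printf("%g%+gi\n", out[0], out[1]);              /* expect 1+10i */
    return 0;
}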
@@ -5531,6 +7541,72 @@ BLIS_INLINE err_t ztrsm_AuXB_ref ymm6 = _mm256_sub_pd(ymm15,ymm6);\ } +#define BLIS_PRE_ZTRSM_SMALL_2x1(AlphaVal,b11,cs_b){\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + xmm5 = _mm_loadu_pd((double const *)(b11));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm3 = _mm256_sub_pd(ymm15,ymm3);\ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 +cs_b));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ +\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm5 = _mm256_sub_pd(ymm15,ymm5);\ +} + +#define BLIS_PRE_ZTRSM_SMALL_2x3(AlphaVal,b11,cs_b) {\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm3 = _mm256_sub_pd(ymm15,ymm3);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm4 = _mm256_sub_pd(ymm15,ymm4);\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b *1 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ +\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm5 = _mm256_sub_pd(ymm15,ymm5);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm6 = _mm256_sub_pd(ymm15,ymm6);\ +} + /** * Multiplies Alpha with 4 element of 3 columns. * ymm0 and ymm1 holds 4 elements of a column. 
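Note: a recurring change in this patch is replacing _mm256_extractf128_pd(v, 0) with _mm256_castpd256_pd128(v) in the store paths. Both return the low 128 bits of v, so the stored values are unchanged; the cast is a compile-time reinterpretation that emits no instruction, whereas the extract form may be lowered to a vextractf128. A minimal check, illustrative only and not part of the patch, compiled with -mavx:

#include <immintrin.h>
#include <assert.h>

int main(void)
{
    __m256d v = _mm256_setr_pd(1.0, 2.0, 3.0, 4.0);

    double lo_cast[2], lo_extract[2];
    _mm_storeu_pd(lo_cast,    _mm256_castpd256_pd128(v));   /* reinterpret low half */
    _mm_storeu_pd(lo_extract, _mm256_extractf128_pd(v, 0)); /* extract low half     */

    /* Same values either way; only the generated code can differ. */
    assert(lo_cast[0] == lo_extract[0] && lo_cast[1] == lo_extract[1]);
    assert(lo_cast[0] == 1.0 && lo_cast[1] == 2.0);
    return 0;
}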
@@ -5592,6 +7668,102 @@ BLIS_INLINE err_t ztrsm_AuXB_ref \ } +#define BLIS_PRE_ZTRSM_SMALL_3x3(AlphaVal,b11,cs_b) {\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm3 = _mm256_sub_pd(ymm15,ymm3);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm4 = _mm256_sub_pd(ymm15,ymm4);\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b *1 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ +\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm5 = _mm256_sub_pd(ymm15,ymm5);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm6 = _mm256_sub_pd(ymm15,ymm6);\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *2));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b *2 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm7 = _mm256_sub_pd(ymm15,ymm7);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm8 = _mm256_sub_pd(ymm15,ymm8);\ + \ +} + +#define BLIS_PRE_ZTRSM_SMALL_3x1(AlphaVal,b11,cs_b) {\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + xmm5 = _mm_loadu_pd((double const *)(b11));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm3 = _mm256_sub_pd(ymm15,ymm3);\ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*1));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ +\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm5 = _mm256_sub_pd(ymm15,ymm5);\ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm7 = _mm256_sub_pd(ymm15,ymm7);\ + \ +} + /* * Pack a block of 4xk or 3xk from input buffer into packed buffer * directly or after transpose based on input params @@ -5782,7 +7954,7 @@ 
BLIS_INLINE err_t ztrsm_AuXB_ref \ ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ ymm12 = _mm256_addsub_pd(ymm12, ymm5);\ - ymm9 = _mm256_addsub_pd(ymm9, ymm6);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm6);\ ymm13 = _mm256_addsub_pd(ymm13, ymm7);\ } @@ -6091,9 +8263,9 @@ BLIS_INLINE void bli_ztrsm_small_pack ymm7 = _mm256_permute2f128_pd(ymm0,ymm5,0x31); ymm8 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); - _mm_storeu_pd((double *)(pbuff + 2), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(pbuff + p_lda + 2), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(pbuff + p_lda * 2 + 2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(pbuff + 2), _mm256_castpd256_pd128(ymm6)); + _mm_storeu_pd((double *)(pbuff + p_lda + 2), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(pbuff + p_lda * 2 + 2), _mm256_castpd256_pd128(ymm8)); inbuf += mr*cs_a; pbuff += mr; @@ -6227,7 +8399,6 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB dim_t k_iter; //determines the number of GEMM operations to be done double ones = 1.0; - double zero = 0.0; bool is_unitdiag = bli_obj_has_unit_diag(a); double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha @@ -6274,6 +8445,8 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB __m128d xmm5; + xmm5 = _mm_setzero_pd(); + /* Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of d_nr a. Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) @@ -6363,6 +8536,13 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB b. Towards the end TRSM output will be stored back into b11 */ + _mm_prefetch((char*)(b11 + 0 + 7), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b + 7), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + 2 * cs_b + 7), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + 3 * cs_b + 7), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + 4 * cs_b + 7), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + 5 * cs_b + 7), _MM_HINT_T0); + //extract a00 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); @@ -6633,10 +8813,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_6nx3m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x3(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -6727,12 +8907,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); - _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); - _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); - _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_castpd256_pd128(ymm11)); + _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_castpd256_pd128(ymm13)); _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3,1)); _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5,1)); @@ -6757,10 +8937,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB 
BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_6nx2m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x2(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -6851,12 +9031,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); - _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); - _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); - _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_castpd256_pd128(ymm11)); + _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_castpd256_pd128(ymm13)); m_remainder -= 2; i += 2; @@ -6874,10 +9054,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_6nx1m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x1(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -6968,12 +9148,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - _mm_storel_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); - _mm_storel_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); - _mm_storel_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); - _mm_storel_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); - _mm_storel_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); - _mm_storel_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); + _mm_storel_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storel_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storel_pd((double *)(b11 + cs_b*2), _mm256_castpd256_pd128(ymm7)); + _mm_storel_pd((double *)(b11 + cs_b*3), _mm256_castpd256_pd128(ymm9)); + _mm_storel_pd((double *)(b11 + cs_b*4), _mm256_castpd256_pd128(ymm11)); + _mm_storel_pd((double *)(b11 + cs_b*5), _mm256_castpd256_pd128(ymm13)); m_remainder -= 1; i += 1; @@ -7028,21 +9208,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), 
_mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_castpd256_pd128(ymm0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm4,1)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm0,1)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; @@ -7301,7 +9472,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_4nx3m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha @@ -7365,10 +9536,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); - _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_castpd256_pd128(ymm9)); _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3,1)); _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5,1)); @@ -7391,7 +9562,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_4nx2m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -7454,10 +9625,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); - _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_castpd256_pd128(ymm9)); m_remainder -= 2; i += 2; @@ -7475,7 +9646,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_4nx1m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -7537,10 +9708,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm9, 0)); + _mm_storel_pd((b11 + cs_b * 0), _mm256_castpd256_pd128(ymm3)); + _mm_storel_pd((b11 + cs_b * 1), _mm256_castpd256_pd128(ymm5)); + _mm_storel_pd((b11 + cs_b * 2), _mm256_castpd256_pd128(ymm7)); + _mm_storel_pd((b11 + cs_b * 3), _mm256_castpd256_pd128(ymm9)); m_remainder -= 1; i += 1; @@ 
-7589,21 +9760,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_castpd256_pd128(ymm0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm4,1)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm0,1)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; @@ -7827,7 +9989,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal,b11,cs_b) @@ -7878,7 +10040,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_3N_2M(AlphaVal,b11,cs_b) @@ -7929,7 +10091,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_3N_1M(AlphaVal,b11,cs_b) @@ -8010,21 +10172,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_castpd256_pd128(ymm0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm4,1)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm0,1)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; @@ -8199,7 +10352,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm5 = _mm256_setzero_pd(); ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) 
BLIS_PRE_DTRSM_SMALL_2N_3M(AlphaVal,b11,cs_b) @@ -8236,7 +10389,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm5 = _mm256_setzero_pd(); ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_2N_2M(AlphaVal,b11,cs_b) @@ -8272,7 +10425,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm5 = _mm256_setzero_pd(); ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_2N_1M(AlphaVal,b11,cs_b) @@ -8311,7 +10464,8 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB if(transa) { - for(dim_t x =0;x < p_lda;x+=d_nr) + dim_t x = 0; + for(x = 0;(x + d_nr - 1) < p_lda;x+=d_nr) { ymm0 = _mm256_loadu_pd((double const *)(a01)); ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); @@ -8339,24 +10493,44 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_castpd256_pd128(ymm0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm4,1)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm0,1)); + + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + dim_t remainder_loop_count = p_lda - x; + if(remainder_loop_count >= 4) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); - - a01 += d_nr*cs_a; - ptr_a10_dup += d_nr; + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + a01 += 4*cs_a; + ptr_a10_dup += 4; + remainder_loop_count = remainder_loop_count - 4; } } else @@ -8479,7 +10653,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_1N_3M(AlphaVal,b11,cs_b) @@ -8506,7 +10680,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_1N_2M(AlphaVal,b11,cs_b) @@ 
-8533,7 +10707,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal,b11,cs_b) @@ -8674,6 +10848,8 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB __m128d xmm5; + xmm5 = _mm_setzero_pd(); + /* Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of d_nr a. Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) @@ -9019,10 +11195,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_6nx3m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x3(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -9105,12 +11281,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); - _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); - _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); - _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_castpd256_pd128(ymm11)); + _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_castpd256_pd128(ymm13)); _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3,1)); _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5,1)); @@ -9134,10 +11310,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_6nx2m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x2(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -9220,12 +11396,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); - _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); - _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); - _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_castpd256_pd128(ymm11)); + _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_castpd256_pd128(ymm13)); m_remainder -=2; } @@ -9242,10 +11418,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM 
implementation starts/// - BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_6nx1m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x1(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -9328,12 +11504,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - _mm_storel_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); - _mm_storel_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); - _mm_storel_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); - _mm_storel_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); - _mm_storel_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); - _mm_storel_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); + _mm_storel_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storel_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storel_pd((double *)(b11 + cs_b*2), _mm256_castpd256_pd128(ymm7)); + _mm_storel_pd((double *)(b11 + cs_b*3), _mm256_castpd256_pd128(ymm9)); + _mm_storel_pd((double *)(b11 + cs_b*4), _mm256_castpd256_pd128(ymm11)); + _mm_storel_pd((double *)(b11 + cs_b*5), _mm256_castpd256_pd128(ymm13)); m_remainder -=1; } @@ -9399,10 +11575,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_castpd256_pd128(ymm6)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_castpd256_pd128(ymm9)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; @@ -9654,7 +11830,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_4nx3m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -9714,10 +11890,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); - _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_castpd256_pd128(ymm9)); _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3,1)); _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5,1)); @@ -9739,7 +11915,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_4nx2m(a01,b10,cs_b,p_lda,k_iter) ymm15 = 
_mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -9798,10 +11974,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); - _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_castpd256_pd128(ymm9)); m_remainder -=2; } @@ -9818,7 +11994,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_4nx1m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -9874,12 +12050,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm9, 0)); + _mm_storel_pd((b11 + cs_b * 0), _mm256_castpd256_pd128(ymm3)); + _mm_storel_pd((b11 + cs_b * 1), _mm256_castpd256_pd128(ymm5)); + _mm_storel_pd((b11 + cs_b * 2), _mm256_castpd256_pd128(ymm7)); + _mm_storel_pd((b11 + cs_b * 3), _mm256_castpd256_pd128(ymm9)); m_remainder -=1; } @@ -9938,10 +12114,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_castpd256_pd128(ymm6)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_castpd256_pd128(ymm9)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; @@ -10163,7 +12339,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal,b11,cs_b) @@ -10210,7 +12386,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_3N_2M(AlphaVal,b11,cs_b) @@ -10258,7 +12434,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + 
BLIS_DTRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_3N_1M(AlphaVal,b11,cs_b) @@ -10347,10 +12523,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_castpd256_pd128(ymm6)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_castpd256_pd128(ymm9)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; @@ -10526,7 +12702,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_2N_3M(AlphaVal,b11,cs_b) @@ -10562,7 +12738,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_2N_2M(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -10597,7 +12773,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_2N_1M(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -10634,7 +12810,8 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB if(transa) { - for(dim_t x =0;x < p_lda;x+=d_nr) + dim_t x =0; + for(x =0;(x+d_nr-1) < p_lda;x+=d_nr) { ymm0 = _mm256_loadu_pd((double const *)(a01)); ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); @@ -10673,14 +12850,42 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_castpd256_pd128(ymm6)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_castpd256_pd128(ymm9)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } + dim_t remainder_loop_count = p_lda - x; + if(remainder_loop_count >= 4) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 
= _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + a01 += 4*cs_a; + ptr_a10_dup += 4; + remainder_loop_count = remainder_loop_count - 4; + } } else { @@ -10801,7 +13006,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_1N_3M(AlphaVal,b11,cs_b) @@ -10829,7 +13034,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_1N_2M(AlphaVal,b11,cs_b) @@ -10854,7 +13059,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal,b11,cs_b) @@ -10935,6 +13140,8 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB __m128d xmm5; + xmm5 = _mm_setzero_pd(); + gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; double *D_A_pack = NULL; @@ -11739,7 +13946,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB dim_t p_lda = 4; // packed leading dimension if(transa) { - for(dim_t x =0;x < m-i+4;x+=p_lda) + for(dim_t x =0;x < m-i-4;x+=p_lda) { ymm0 = _mm256_loadu_pd((double const *)(a10)); ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); @@ -12296,35 +14503,54 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); + ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); + ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); + ymm2 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 2 + 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm3 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 3 + 2)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10)); + _mm_storeu_pd((double *)(b11 + cs_b * 
3), _mm256_castpd256_pd128(ymm11)); + + _mm_storel_pd((double *)(b11 + 2), _mm256_extractf128_pd(ymm8,1)); + _mm_storel_pd((double *)(b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm9,1)); + _mm_storel_pd((double *)(b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm10,1)); + _mm_storel_pd((double *)(b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm11,1)); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 4)); + ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 4 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 5)); + ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 5 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + ymm4 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm5 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + _mm_storeu_pd((double *)(b11 + cs_b * 4), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(b11 + cs_b * 5), _mm256_castpd256_pd128(ymm5)); - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + _mm_storel_pd((double *)(b11 + cs_b * 4 + 2), _mm256_extractf128_pd(ymm4,1)); + _mm_storel_pd((double *)(b11 + cs_b * 5 + 2), _mm256_extractf128_pd(ymm5,1)); if(transa) dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); @@ -12352,11 +14578,20 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); + ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); + ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); + ymm2 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 2 + 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); + ymm3 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 3 + 2)); ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); @@ -12364,17 +14599,15 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10)); + _mm_storeu_pd((double *)(b11 + cs_b * 
3), _mm256_castpd256_pd128(ymm11)); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); + _mm_storel_pd((double *)(b11 + 2), _mm256_extractf128_pd(ymm8,1)); + _mm_storel_pd((double *)(b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm9,1)); + _mm_storel_pd((double *)(b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm10,1)); + _mm_storel_pd((double *)(b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm11,1)); if(transa) dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); @@ -12492,35 +14725,40 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_castpd256_pd128(ymm11)); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 4)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 5)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm4 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm5 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + _mm_storeu_pd((double *)(b11 + cs_b * 4), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(b11 + cs_b * 5), _mm256_castpd256_pd128(ymm5)); if(transa) dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); @@ -12547,9 
+14785,15 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); @@ -12558,16 +14802,10 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); + _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_castpd256_pd128(ymm11)); if(transa) dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); @@ -12684,35 +14922,30 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); + _mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8,0)); + _mm_storel_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm9,0)); + _mm_storel_pd((double *)(b11 + cs_b * 2), _mm256_extractf128_pd(ymm10,0)); + _mm_storel_pd((double *)(b11 + cs_b * 3), _mm256_extractf128_pd(ymm11,0)); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = 
_mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + ymm4 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm5 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + _mm_storel_pd((double *)(b11 + cs_b * 4), _mm256_extractf128_pd(ymm4,0)); + _mm_storel_pd((double *)(b11 + cs_b * 5), _mm256_extractf128_pd(ymm5,0)); if(transa) dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); @@ -12739,24 +14972,20 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + _mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8,0)); + _mm_storel_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm9,0)); + _mm_storel_pd((double *)(b11 + cs_b * 2), _mm256_extractf128_pd(ymm10,0)); + _mm_storel_pd((double *)(b11 + cs_b * 3), _mm256_extractf128_pd(ymm11,0)); if(transa) dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); @@ -12823,6 +15052,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB return BLIS_SUCCESS; } + /* TRSM for the Left Upper case AX = alpha * B, Double precision * A is Left side, upper-triangular, transpose, non-unit/unit diagonal * dimensions A: mxm X: mxn B: mxn @@ -12914,6 +15144,8 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB __m128d xmm5; + xmm5 = _mm_setzero_pd(); + gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; double *D_A_pack = NULL; @@ -14356,35 +16588,53 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); + ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); + ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); + ymm2 = 
_mm256_broadcast_sd((double const*)(b11 + cs_b * 2 + 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm3 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 3 + 2)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_castpd256_pd128(ymm11)); + + _mm_storel_pd((double *)(b11 + 2), _mm256_extractf128_pd(ymm8,1)); + _mm_storel_pd((double *)(b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm9,1)); + _mm_storel_pd((double *)(b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm10,1)); + _mm_storel_pd((double *)(b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm11,1)); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 4)); + ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 4 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 5)); + ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 5 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + ymm4 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm5 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + _mm_storeu_pd((double *)(b11 + cs_b * 4), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(b11 + cs_b * 5), _mm256_castpd256_pd128(ymm5)); - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + _mm_storel_pd((double *)(b11 + cs_b * 4 + 2), _mm256_extractf128_pd(ymm4,1)); + _mm_storel_pd((double *)(b11 + cs_b * 5 + 2), _mm256_extractf128_pd(ymm5,1)); if(transa) dtrsm_AutXB_ref(a11, b11, m_rem, 6, cs_a, cs_b,is_unitdiag); @@ -14412,28 +16662,36 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); + ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); + ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); + ymm2 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 2 + 2)); + ymm2 = _mm256_insertf128_pd(ymm2, 
xmm5, 0); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); + ymm3 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 3 + 2)); ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_castpd256_pd128(ymm11)); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); + _mm_storel_pd((double *)(b11 + 2), _mm256_extractf128_pd(ymm8,1)); + _mm_storel_pd((double *)(b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm9,1)); + _mm_storel_pd((double *)(b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm10,1)); + _mm_storel_pd((double *)(b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm11,1)); if(transa) dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b,is_unitdiag); @@ -14553,35 +16811,39 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_castpd256_pd128(ymm11)); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 4)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm0 = _mm256_loadu_pd((double 
const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 5)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm4 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm5 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + _mm_storeu_pd((double *)(b11 + cs_b * 4), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(b11 + cs_b * 5), _mm256_castpd256_pd128(ymm5)); if(transa) dtrsm_AutXB_ref(a11, b11, m_rem, 6, cs_a, cs_b, is_unitdiag); @@ -14609,9 +16871,15 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); @@ -14620,16 +16888,10 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); + _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_castpd256_pd128(ymm11)); if(transa) dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); @@ -14750,35 +17012,29 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); + 
_mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8,0)); + _mm_storel_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm9,0)); + _mm_storel_pd((double *)(b11 + cs_b * 2), _mm256_extractf128_pd(ymm10,0)); + _mm_storel_pd((double *)(b11 + cs_b * 3), _mm256_extractf128_pd(ymm11,0)); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + ymm4 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm5 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + _mm_storel_pd((double *)(b11 + cs_b * 4), _mm256_extractf128_pd(ymm4,0)); + _mm_storel_pd((double *)(b11 + cs_b * 5), _mm256_extractf128_pd(ymm5,0)); if(transa) dtrsm_AutXB_ref(a11, b11, m_rem, 6, cs_a, cs_b, is_unitdiag); @@ -14814,15 +17070,10 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm3, 0)); + _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm8,0)); + _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm9,0)); + _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm10,0)); + _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm11,0)); if(transa) dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); @@ -15521,6 +17772,8 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB __m128 xmm5; + xmm5 = _mm_setzero_ps(); + /* Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of d_nr a. 
Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) @@ -15860,10 +18113,10 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx7m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x7(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -15946,25 +18199,29 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x7F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storel_pi((__m64 *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_store_ss((float *)(b11 + 6),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*3),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm9,ymm9), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*4),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm11,ymm11), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*5),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm13,ymm13), 1)); m_remainder -=7; } @@ -15981,10 +18238,10 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx6m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + 
BLIS_PRE_STRSM_SMALL_6x6(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16067,25 +18324,18 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x3F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storel_pi((__m64 *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); m_remainder -=6; } @@ -16102,10 +18352,10 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx5m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x5(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16188,25 +18438,18 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x1F); - - _mm256_storeu_ps((float *)b11, ymm3); 
- _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_store_ss((float *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); m_remainder -=5; } @@ -16223,10 +18466,10 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x4(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16309,25 +18552,12 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x0F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); m_remainder -=4; } @@ -16344,10 +18574,10 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx3m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - 
BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x3(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16430,25 +18660,29 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x07); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + xmm5 = _mm256_extractf128_ps(ymm3, 0); + _mm_storel_pi((__m64 *)(b11),xmm5); + _mm_store_ss((float *)(b11+2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 0)); + + xmm5 = _mm256_extractf128_ps(ymm5, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm5); + _mm_store_ss((float *)(b11+ 2 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 0)); + + xmm5 = _mm256_extractf128_ps(ymm7, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*2),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 0)); + + xmm5 = _mm256_extractf128_ps(ymm9, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*3),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*3),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm9,ymm9), 0)); + + xmm5 = _mm256_extractf128_ps(ymm11, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*4),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*4),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm11,ymm11), 0)); + + xmm5 = _mm256_extractf128_ps(ymm13, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*5),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*5),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm13,ymm13), 0)); m_remainder -=3; } @@ -16465,10 +18699,10 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx2m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x2(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16551,25 +18785,23 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_ps((float const 
*)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x03); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + xmm5 = _mm256_extractf128_ps(ymm3, 0); + _mm_storel_pi((__m64 *)(b11),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm5, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm7, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*2),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm9, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*3),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm11, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*4),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm13, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*5),xmm5); m_remainder -=2; } @@ -16586,10 +18818,10 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx1m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x1(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16672,25 +18904,12 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x01); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_store_ss((b11 + cs_b * 0), _mm256_extractf128_ps(ymm3, 0)); + _mm_store_ss((b11 + cs_b * 1), _mm256_extractf128_ps(ymm5, 0)); + _mm_store_ss((b11 + cs_b * 2), _mm256_extractf128_ps(ymm7, 0)); + _mm_store_ss((b11 + cs_b * 3), _mm256_extractf128_ps(ymm9, 0)); + _mm_store_ss((b11 + cs_b * 4), _mm256_extractf128_ps(ymm11, 0)); + _mm_store_ss((b11 + cs_b * 5), _mm256_extractf128_ps(ymm13, 0)); m_remainder -=1; } @@ -17009,7 +19228,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx7m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha @@ -17123,7 +19342,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS 
 ///GEMM implementation starts///
-BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter)
+BLIS_STRSM_SMALL_GEMM_4nx6m(a01,b10,cs_b,p_lda,k_iter)
 ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha
@@ -17226,7 +19445,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB
 BLIS_SET_S_YMM_REG_ZEROS
 ///GEMM implementation starts///
-BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter)
+BLIS_STRSM_SMALL_GEMM_4nx5m(a01,b10,cs_b,p_lda,k_iter)
 ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha
@@ -17317,7 +19536,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB
 BLIS_SET_S_YMM_REG_ZEROS
 ///GEMM implementation starts///
-BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter)
+BLIS_STRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter)
 ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha
@@ -17400,7 +19619,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB
 BLIS_SET_S_YMM_REG_ZEROS
 ///GEMM implementation starts///
-BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter)
+BLIS_STRSM_SMALL_GEMM_4nx3m(a01,b10,cs_b,p_lda,k_iter)
 ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha
@@ -17498,7 +19717,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB
 BLIS_SET_S_YMM_REG_ZEROS
 ///GEMM implementation starts///
-BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter)
+BLIS_STRSM_SMALL_GEMM_4nx2m(a01,b10,cs_b,p_lda,k_iter)
 ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha
@@ -17589,7 +19808,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB
 BLIS_SET_S_YMM_REG_ZEROS
 ///GEMM implementation starts///
-BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter)
+BLIS_STRSM_SMALL_GEMM_4nx1m(a01,b10,cs_b,p_lda,k_iter)
 ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha
@@ -17932,7 +20151,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB
 BLIS_SET_S_YMM_REG_ZEROS
 ///GEMM implementation starts///
-BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter)
+BLIS_STRSM_SMALL_GEMM_3nx7m(a01,b10,cs_b,p_lda,k_iter)
 BLIS_PRE_STRSM_SMALL_3N_7M(AlphaVal,b11,cs_b)
@@ -17979,7 +20198,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB
 BLIS_SET_S_YMM_REG_ZEROS
 ///GEMM implementation starts///
-BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter)
+BLIS_STRSM_SMALL_GEMM_3nx6m(a01,b10,cs_b,p_lda,k_iter)
 BLIS_PRE_STRSM_SMALL_3N_6M(AlphaVal,b11,cs_b)
@@ -18026,7 +20245,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB
 BLIS_SET_S_YMM_REG_ZEROS
 ///GEMM implementation starts///
-BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter)
+BLIS_STRSM_SMALL_GEMM_3nx5m(a01,b10,cs_b,p_lda,k_iter)
 BLIS_PRE_STRSM_SMALL_3N_5M(AlphaVal,b11,cs_b)
@@ -18074,7 +20293,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB
 BLIS_SET_S_YMM_REG_ZEROS
 ///GEMM implementation starts///
-BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter)
+BLIS_STRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter)
 BLIS_PRE_STRSM_SMALL_3N_4M(AlphaVal,b11,cs_b)
@@ -18121,7 +20340,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB
 BLIS_SET_S_YMM_REG_ZEROS
 ///GEMM implementation starts///
-BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter)
+BLIS_STRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter)
 BLIS_PRE_STRSM_SMALL_3N_3M(AlphaVal,b11,cs_b)
@@ -18168,7 +20387,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB
 BLIS_SET_S_YMM_REG_ZEROS
 ///GEMM implementation starts///
-BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter)
+BLIS_STRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter)
 BLIS_PRE_STRSM_SMALL_3N_2M(AlphaVal,b11,cs_b)
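Note on the m-remainder hunks above and below: each one replaces the old load/blend/256-bit-store sequence with stores that write exactly m elements, so the kernel no longer touches memory past the m-remainder entries of each column of B11. A minimal standalone sketch of that store pattern, assuming only <immintrin.h>; the helper name store_m_floats is hypothetical and not part of BLIS:

    #include <immintrin.h>

    /* Hypothetical helper: write only the first m floats (m < 8) of a YMM
     * register, using the same extract/store combination the hunks above
     * introduce, so no bytes beyond the remainder are written. */
    static inline void store_m_floats(float *dst, __m256 v, int m)
    {
        __m128 lo = _mm256_extractf128_ps(v, 0);   /* elements 0-3 */
        __m128 hi = _mm256_extractf128_ps(v, 1);   /* elements 4-7 */
        switch (m)
        {
            case 7: _mm_storeu_ps(dst, lo);                    /* 0-3 */
                    _mm_storel_pi((__m64 *)(dst + 4), hi);     /* 4,5 */
                    _mm_store_ss(dst + 6,                      /* 6   */
                        _mm256_extractf128_ps(_mm256_unpackhi_ps(v, v), 1));
                    break;
            case 4: _mm_storeu_ps(dst, lo);                    /* 0-3 */
                    break;
            case 2: _mm_storel_pi((__m64 *)dst, lo);           /* 0,1 */
                    break;
            case 1: _mm_store_ss(dst, lo);                     /* 0   */
                    break;
            default: break; /* other widths combine the same pieces */
        }
    }

The double-precision hunks earlier in the patch apply the same idea with _mm256_castpd256_pd128/_mm_storeu_pd for the low pair and _mm_storel_pd on the extracted upper lane for the third element.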
@@ -18216,7 +20435,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_1M(AlphaVal,b11,cs_b) @@ -18478,7 +20697,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx7m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_7M(AlphaVal,b11,cs_b) @@ -18514,7 +20733,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx6m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_6M(AlphaVal,b11,cs_b) @@ -18550,7 +20769,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx5m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_5M(AlphaVal,b11,cs_b) @@ -18586,7 +20805,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_4M(AlphaVal,b11,cs_b) @@ -18622,7 +20841,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_3M(AlphaVal,b11,cs_b) @@ -18658,7 +20877,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_2M(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -18693,7 +20912,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_1M(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -18733,7 +20952,8 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB __m128 xmm0, xmm1, xmm2, xmm3; __m128 xmm4, xmm5, xmm6, xmm7; __m128 xmm8, xmm9; - for(dim_t x =0;x < p_lda;x+=d_nr) + dim_t x = 0; + for(x =0;(x+d_nr-1) < p_lda;x+=d_nr) { xmm0 = _mm_loadu_ps((float const *)(a01)); xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); @@ -18776,6 +20996,33 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } + dim_t remainder_count = p_lda - x; + if(remainder_count >= 4) + { + xmm0 = _mm_loadu_ps((float const *)(a01)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); + xmm2 = _mm_loadu_ps((float const *)(a01 + cs_a * 2)); + xmm3 = _mm_loadu_ps((float const *)(a01 + cs_a * 3)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + 
_mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); + + a01 += 4*cs_a; + ptr_a10_dup += 4; + remainder_count = remainder_count - 4; + } } else { @@ -18797,8 +21044,6 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB xmm0 = _mm_loadu_ps((float *)(a01 + rs_a * 0 + loop_count*6)); _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 0 + loop_count*6), xmm0); - xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 0 + 4 + loop_count*6)); - _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 0 + 4 + loop_count*6),xmm0); } } @@ -18900,7 +21145,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx7m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_7M(AlphaVal,b11,cs_b) @@ -18926,7 +21171,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx6m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_6M(AlphaVal,b11,cs_b) @@ -18952,7 +21197,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx5m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_5M(AlphaVal,b11,cs_b) @@ -18977,7 +21222,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_4M(AlphaVal,b11,cs_b) @@ -19002,7 +21247,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_3M(AlphaVal,b11,cs_b) @@ -19027,7 +21272,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_2M(AlphaVal,b11,cs_b) @@ -19052,7 +21297,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_1M(AlphaVal,b11,cs_b) @@ -19192,6 +21437,8 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB __m128 xmm5; + xmm5 = _mm_setzero_ps(); + /* Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of d_nr a. 
Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) @@ -19550,10 +21797,10 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx7m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x7(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -19644,25 +21891,29 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x7F); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storel_pi((__m64 *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_store_ss((float *)(b11 + 6),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 1)); - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*3),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm9,ymm9), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*4),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm11,ymm11), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*5),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm13,ymm13), 1)); m_remainder -= 7; i += 7; @@ -19680,10 +21931,10 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx6m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + 
BLIS_PRE_STRSM_SMALL_6x6(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -19774,25 +22025,18 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x3F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storel_pi((__m64 *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); m_remainder -= 6; i += 6; @@ -19810,10 +22054,10 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx5m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x5(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -19904,25 +22148,18 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x1F); - - _mm256_storeu_ps((float 
*)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_store_ss((float *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); m_remainder -= 5; i += 5; @@ -19940,10 +22177,10 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x4(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -20034,25 +22271,12 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x0F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); m_remainder -= 4; i += 4; @@ -20070,10 +22294,10 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx3m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha - 
BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x3(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -20164,25 +22388,29 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x07); + xmm5 = _mm256_extractf128_ps(ymm3, 0); + _mm_storel_pi((__m64 *)(b11),xmm5); + _mm_store_ss((float *)(b11+2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 0)); - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + xmm5 = _mm256_extractf128_ps(ymm5, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm5); + _mm_store_ss((float *)(b11+ 2 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 0)); + + xmm5 = _mm256_extractf128_ps(ymm7, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*2),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 0)); + + xmm5 = _mm256_extractf128_ps(ymm9, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*3),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*3),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm9,ymm9), 0)); + + xmm5 = _mm256_extractf128_ps(ymm11, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*4),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*4),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm11,ymm11), 0)); + + xmm5 = _mm256_extractf128_ps(ymm13, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*5),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*5),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm13,ymm13), 0)); m_remainder -= 3; i += 3; @@ -20200,10 +22428,10 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx2m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x2(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -20294,25 +22522,23 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_ps((float 
const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x03); + xmm5 = _mm256_extractf128_ps(ymm3, 0); + _mm_storel_pi((__m64 *)(b11),xmm5); - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + xmm5 = _mm256_extractf128_ps(ymm5, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm7, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*2),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm9, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*3),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm11, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*4),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm13, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*5),xmm5); m_remainder -= 2; i += 2; @@ -20330,10 +22556,10 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx1m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x1(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -20424,25 +22650,12 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x01); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_store_ss((b11 + cs_b * 0), _mm256_extractf128_ps(ymm3, 0)); + _mm_store_ss((b11 + cs_b * 1), _mm256_extractf128_ps(ymm5, 0)); + _mm_store_ss((b11 + cs_b * 2), _mm256_extractf128_ps(ymm7, 0)); + _mm_store_ss((b11 + cs_b * 3), _mm256_extractf128_ps(ymm9, 0)); + _mm_store_ss((b11 + cs_b * 4), _mm256_extractf128_ps(ymm11, 0)); + _mm_store_ss((b11 + cs_b * 5), _mm256_extractf128_ps(ymm13, 0)); m_remainder -= 1; i += 1; @@ -20769,7 +22982,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx7m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha @@ -20888,7 +23101,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB 
BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx6m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha @@ -20996,7 +23209,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx5m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha @@ -21092,7 +23305,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha @@ -21180,7 +23393,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx3m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha @@ -21283,7 +23496,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx2m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha @@ -21379,7 +23592,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx1m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha @@ -21730,7 +23943,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx7m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_7M(AlphaVal,b11,cs_b) @@ -21781,7 +23994,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx6m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_6M(AlphaVal,b11,cs_b) @@ -21832,7 +24045,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx5m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_5M(AlphaVal,b11,cs_b) @@ -21883,7 +24096,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_4M(AlphaVal,b11,cs_b) @@ -21934,7 +24147,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_3M(AlphaVal,b11,cs_b) @@ -21985,7 +24198,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) 
BLIS_PRE_STRSM_SMALL_3N_2M(AlphaVal,b11,cs_b) @@ -22036,7 +24249,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_1M(AlphaVal,b11,cs_b) @@ -22301,7 +24514,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm5 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx7m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_7M(AlphaVal,b11,cs_b) @@ -22338,7 +24551,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm5 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx6m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_6M(AlphaVal,b11,cs_b) @@ -22375,7 +24588,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm5 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx5m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_5M(AlphaVal,b11,cs_b) @@ -22412,7 +24625,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm5 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_4M(AlphaVal,b11,cs_b) @@ -22449,7 +24662,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm5 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_3M(AlphaVal,b11,cs_b) @@ -22486,7 +24699,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm5 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_2M(AlphaVal,b11,cs_b) @@ -22522,7 +24735,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm5 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_1M(AlphaVal,b11,cs_b) @@ -22566,7 +24779,8 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB __m128 xmm4, xmm5, xmm6, xmm7; __m128 xmm8, xmm9; - for(dim_t x =0;x < p_lda;x+=d_nr) + dim_t x = 0; + for(x =0;(x+d_nr-1) < p_lda;x+=d_nr) { xmm0 = _mm_loadu_ps((float const *)(a01)); xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); @@ -22609,6 +24823,32 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } + dim_t remainder_count = p_lda - x; + if(remainder_count >= 4) + { + xmm0 = _mm_loadu_ps((float const *)(a01)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); + xmm2 = _mm_loadu_ps((float const *)(a01 + cs_a * 2)); + xmm3 = _mm_loadu_ps((float const *)(a01 + cs_a * 3)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + 
_mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); + + a01 += 4*cs_a; + ptr_a10_dup += 4; + } } else { @@ -22630,8 +24870,6 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB xmm0 = _mm_loadu_ps((float *)(a01 + rs_a * 0 + loop_count*6)); _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 0 + loop_count*6), xmm0); - xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 0 + 4 + loop_count*6)); - _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 0 + 4 + loop_count*6),xmm0); } } @@ -22735,7 +24973,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx7m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_7M(AlphaVal,b11,cs_b) @@ -22761,7 +24999,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx6m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_6M(AlphaVal,b11,cs_b) @@ -22787,7 +25025,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx5m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_5M(AlphaVal,b11,cs_b) @@ -22813,7 +25051,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_4M(AlphaVal,b11,cs_b) @@ -22839,7 +25077,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_3M(AlphaVal,b11,cs_b) @@ -22866,7 +25104,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_2M(AlphaVal,b11,cs_b) @@ -22893,7 +25131,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_1M(AlphaVal,b11,cs_b) @@ -29457,7 +31695,7 @@ BLIS_INLINE err_t bli_strsm_small_AltXB_AuXB dim_t p_lda = 8; // packed leading dimension if(transa) { - for(dim_t x =0;x < m-i+8;x+=p_lda) + for(dim_t x =0;x < m-i-8;x+=p_lda) { ymm0 = _mm256_loadu_ps((float const *)(a10)); ymm1 = _mm256_loadu_ps((float const *)(a10 + cs_a)); @@ -30375,7 +32613,7 @@ BLIS_INLINE err_t bli_strsm_small_AltXB_AuXB __m128 xmm6,xmm7,xmm8,xmm9; if(transa) { - for(dim_t x =0;x < m-i+4;x+=p_lda) + for(dim_t x =0;x < m-i-4;x+=p_lda) { xmm0 = _mm_loadu_ps((float const *)(a10)); xmm1 = _mm_loadu_ps((float const *)(a10 + cs_a)); @@ -31519,6 +33757,9 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB __m128d xmm5, xmm4; + xmm4 = _mm_setzero_pd(); + xmm5 = _mm_setzero_pd(); + gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; dcomplex *D_A_pack = NULL; @@ -32305,12 +34546,15 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9); _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm10); - ymm0 = _mm256_loadu_pd((double const *) - (b11 + cs_b *0 + 2)); - 
ymm1 = _mm256_loadu_pd((double const *) - (b11 + cs_b *1 + 2)); - ymm2 = _mm256_loadu_pd((double const *) - (b11 + cs_b *2 + 2)); + xmm4 = _mm_loadu_pd((double const *) + (b11 + cs_b *0 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0); + xmm4 = _mm_loadu_pd((double const *) + (b11 + cs_b *1 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); + xmm4 = _mm_loadu_pd((double const *) + (b11 + cs_b *2 + 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm4, 0); ymm14 = _mm256_permute_pd(ymm16, 0x5); ymm14 = _mm256_mul_pd(ymm14, ymm18); @@ -32335,11 +34579,11 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB ymm13 = _mm256_sub_pd(ymm15,ymm13); _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm11,0)); + _mm256_castpd256_pd128(ymm11)); _mm_storeu_pd((double *)(b11 + cs_b * 1 + 2), - _mm256_extractf128_pd(ymm12,0)); + _mm256_castpd256_pd128(ymm12)); _mm_storeu_pd((double *)(b11 + cs_b * 2 + 2), - _mm256_extractf128_pd(ymm13,0)); + _mm256_castpd256_pd128(ymm13)); if(transa) ztrsm_AutXB_ref(a11, b11, m_rem, 3, @@ -32501,7 +34745,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB if(2 == n_rem) { ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + BLIS_ZTRSM_SMALL_GEMM_2mx2n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_ZTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) @@ -32518,7 +34762,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB else if(1 == n_rem) { ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + BLIS_ZTRSM_SMALL_GEMM_2mx1n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_ZTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) @@ -32541,35 +34785,33 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB { dim_t p_lda = 2; // packed leading dimension if(transa) - { - dim_t x = 0; - for(x = 0; (x + 1) < i; x += p_lda) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - _mm_storeu_pd((double *)(ptr_a10_dup), - _mm256_extractf128_pd(ymm0, 0)); - _mm_storeu_pd((double *)(ptr_a10_dup + - p_lda), _mm256_extractf128_pd(ymm0, 1)); - a10 += p_lda; - ptr_a10_dup += p_lda * p_lda; - } - for(; x < i; x += 1) - { - xmm4 = _mm_loadu_pd((double const *)(a10)); - _mm_storeu_pd((double *)(ptr_a10_dup), xmm4); - a10 += 1; - ptr_a10_dup += 1; - } + { + dim_t x = 0; + for(x = 0; (x + 1) < i; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + _mm_storeu_pd((double *)(ptr_a10_dup), + _mm256_castpd256_pd128(ymm0)); + _mm_storeu_pd((double *)(ptr_a10_dup + + p_lda), _mm256_extractf128_pd(ymm0, 1)); + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + for(; x < i; x += 1) + { + xmm4 = _mm_loadu_pd((double const *)(a10)); + _mm_storeu_pd((double *)(ptr_a10_dup), xmm4); + a10 += 1; + ptr_a10_dup += 1; + } - } + } else { for(dim_t x=0;x 0; j -= d_nr) { @@ -33801,14 +36052,17 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB BLIS_SET_YMM_REG_ZEROS ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_2mx3n(a10,b01,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_1mx3n(a10,b01,cs_b,p_lda,k_iter) ///GEMM code ends/// ymm16 = _mm256_broadcast_pd((__m128d const *) (&AlphaVal)); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + xmm4 = _mm_loadu_pd((double const *)(b11 + cs_b *0)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0); + xmm4 = _mm_loadu_pd((double const *)(b11 + cs_b *1)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); + xmm4 = _mm_loadu_pd((double const *)(b11 + cs_b *2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm4, 0); ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); ymm14 = _mm256_permute_pd(ymm16, 0x5); @@ 
-33835,11 +36089,11 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB ymm10 = _mm256_sub_pd(ymm15,ymm10); _mm_storeu_pd((double *)(b11), - _mm256_extractf128_pd(ymm8,0)); + _mm256_castpd256_pd128(ymm8)); _mm_storeu_pd((double *)(b11 + cs_b * 1), - _mm256_extractf128_pd(ymm9,0) ); + _mm256_castpd256_pd128(ymm9) ); _mm_storeu_pd((double *)(b11 + cs_b * 2), - _mm256_extractf128_pd(ymm10,0)); + _mm256_castpd256_pd128(ymm10)); if(transa) ztrsm_AltXB_ref(a11, b11, m_remainder, 3, @@ -33864,7 +36118,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB { ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + BLIS_ZTRSM_SMALL_GEMM_1mx2n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_ZTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b) @@ -33881,7 +36135,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB else if(1 == n_remainder) { ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + BLIS_ZTRSM_SMALL_GEMM_1mx1n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_ZTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b) @@ -33991,6 +36245,8 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB __m128d xmm5; + xmm5 = _mm_setzero_pd(); + for(j = (n-d_nr); (j+1) > 0; j -= d_nr) //loop along 'N' direction { a01 = L + (j*rs_a) + (j+d_nr)*cs_a; @@ -34239,7 +36495,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB where k_iter are zero */ - BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) /* Load b11 multiply with alpha @@ -34247,7 +36503,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB and peform TRSM operation. */ - BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_3x3(AlphaVal,b11,cs_b) ///implement TRSM/// /* Compute 3x3 TRSM block by using GEMM block output in @@ -34405,15 +36661,15 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB #endif _mm256_storeu_pd((double *)b11, ymm3); _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm4,0)); + _mm256_castpd256_pd128(ymm4)); _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); _mm_storeu_pd((double *)(b11 + cs_b + 2), - _mm256_extractf128_pd(ymm6,0)); + _mm256_castpd256_pd128(ymm6)); _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); _mm_storeu_pd((double *)(b11 + cs_b*2 + 2), - _mm256_extractf128_pd(ymm8,0)); + _mm256_castpd256_pd128(ymm8)); m_remainder -=3; } else if(2 == m_remainder) @@ -34580,7 +36836,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB where k_iter are zero */ - BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) /* Load b11 and multiply with alpha @@ -34588,7 +36844,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB and peform TRSM operation. 
*/ - BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_3x1(AlphaVal,b11,cs_b) ///implement TRSM/// /* Compute 3x3 TRSM block by using GEMM block output @@ -34710,11 +36966,11 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB BLIS_ZTRSM_MUL(ymm3) #endif _mm_storeu_pd((double *)b11, - _mm256_extractf128_pd(ymm3,0)); + _mm256_castpd256_pd128(ymm3)); _mm_storeu_pd((double *)(b11 + cs_b), - _mm256_extractf128_pd(ymm5,0)); + _mm256_castpd256_pd128(ymm5)); _mm_storeu_pd((double *)(b11 + cs_b*2), - _mm256_extractf128_pd(ymm7,0)); + _mm256_castpd256_pd128(ymm7)); m_remainder -=1; } } @@ -34757,11 +37013,11 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB ymm5 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); _mm_storeu_pd((double *)(ptr_a10_dup + 2), - _mm256_extractf128_pd(ymm3,0)); + _mm256_castpd256_pd128(ymm3)); _mm_storeu_pd((double *)(ptr_a10_dup + p_lda + 2), - _mm256_extractf128_pd(ymm4,0)); + _mm256_castpd256_pd128(ymm4)); _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + 2), - _mm256_extractf128_pd(ymm5, 0)); + _mm256_castpd256_pd128(ymm5)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } @@ -34913,10 +37169,12 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB For first itteration there will be no GEMM operation where k_iter are zero */ - BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + //BLIS_ZTRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) // Load b11 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + //BLIS_PRE_ZTRSM_SMALL_3x3(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_2x3(AlphaVal,b11,cs_b) ////extract a00 ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); @@ -34977,11 +37235,11 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB #endif _mm256_storeu_pd((double *)b11, ymm3); _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm4,0)); + _mm256_castpd256_pd128(ymm4)); _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); _mm_storeu_pd((double *)(b11 + cs_b + 2), - _mm256_extractf128_pd(ymm6,0)); + _mm256_castpd256_pd128(ymm6)); m_remainder -=3; } if(2 == m_remainder) @@ -35000,10 +37258,10 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB For first itteration there will be no GEMM operation where k_iter are zero */ - BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) // Load b11 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_2x2(AlphaVal,b11,cs_b) ////extract a00 ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); @@ -35071,10 +37329,10 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB For first itteration there will be no GEMM operation where k_iter are zero */ - BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter) // Load b11 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_2x1(AlphaVal,b11,cs_b) ////extract a00 ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); @@ -35123,9 +37381,9 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB BLIS_ZTRSM_MUL(ymm3) #endif _mm_storeu_pd((double *)b11, - _mm256_extractf128_pd(ymm3,0)); + _mm256_castpd256_pd128(ymm3)); _mm_storeu_pd((double *)(b11 + cs_b), - _mm256_extractf128_pd(ymm5,0)); + _mm256_castpd256_pd128(ymm5)); m_remainder -=1; } n_remainder -= 2; @@ -35167,12 +37425,12 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB ymm5 = 
_mm256_permute2f128_pd(ymm1,ymm5,0x20); _mm_storeu_pd((double *)(ptr_a10_dup + 2), - _mm256_extractf128_pd(ymm3,0)); + _mm256_castpd256_pd128(ymm3)); _mm_storeu_pd((double *)(ptr_a10_dup + p_lda + 2), - _mm256_extractf128_pd(ymm4,0)); + _mm256_castpd256_pd128(ymm4)); _mm_storeu_pd((double *) (ptr_a10_dup + p_lda * 2 + 2), - _mm256_extractf128_pd(ymm5, 0)); + _mm256_castpd256_pd128(ymm5)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } @@ -35264,7 +37522,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_ZTRSM_SMALL_1x3(b11,cs_b,AlphaVal) ///implement TRSM/// @@ -35283,7 +37541,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB _mm256_storeu_pd((double *)b11, ymm3); _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm4,0)); + _mm256_castpd256_pd128(ymm4)); m_remainder -=3; } @@ -35351,7 +37609,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB BLIS_ZTRSM_MUL(ymm3) #endif _mm_storeu_pd((double *)b11, - _mm256_extractf128_pd(ymm3,0)); + _mm256_castpd256_pd128(ymm3)); m_remainder -=1; } n_remainder -= 1; @@ -35450,6 +37708,8 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB __m128d xmm5; + xmm5 = _mm_setzero_pd(); + for(j = 0; (j+d_nr-1) < n; j += d_nr) //loop along 'N' direction { a01 = L + j*rs_a;//pointer to block of A to be used in GEMM @@ -35695,10 +37955,10 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_3x3(AlphaVal,b11,cs_b) ///implement TRSM/// ////extract a00 @@ -35850,15 +38110,15 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB _mm256_storeu_pd((double *)b11, ymm3); _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm4,0)); + _mm256_castpd256_pd128(ymm4)); _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); _mm_storeu_pd((double *)(b11 + cs_b + 2), - _mm256_extractf128_pd(ymm6,0)); + _mm256_castpd256_pd128(ymm6)); _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); _mm_storeu_pd((double *)(b11 + cs_b*2 + 2), - _mm256_extractf128_pd(ymm8,0)); + _mm256_castpd256_pd128(ymm8)); m_remainder -= 3; i += 3; @@ -36012,10 +38272,10 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 2x3 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_3x1(AlphaVal,b11,cs_b) ///implement TRSM/// ////extract a00 @@ -36132,11 +38392,11 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB _mm_storeu_pd((double *)b11, - _mm256_extractf128_pd(ymm3,0)); + _mm256_castpd256_pd128(ymm3)); _mm_storeu_pd((double *)(b11 + cs_b), - _mm256_extractf128_pd(ymm5,0)); + _mm256_castpd256_pd128(ymm5)); _mm_storeu_pd((double *)(b11 + cs_b*2), - _mm256_extractf128_pd(ymm7,0)); + _mm256_castpd256_pd128(ymm7)); m_remainder -= 1; i += 1; @@ -36184,11 +38444,11 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB ymm5 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); _mm_storeu_pd((double *)(ptr_a10_dup + 2), - _mm256_extractf128_pd(ymm3,0)); + _mm256_castpd256_pd128(ymm3)); _mm_storeu_pd((double *)(ptr_a10_dup + p_lda + 2), - _mm256_extractf128_pd(ymm4,0)); + 
_mm256_castpd256_pd128(ymm4)); _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + 2), - _mm256_extractf128_pd(ymm5, 0)); + _mm256_castpd256_pd128(ymm5)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } @@ -36338,10 +38598,10 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_2x3(AlphaVal,b11,cs_b) ///implement TRSM/// ////extract a00 @@ -36405,11 +38665,11 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB _mm256_storeu_pd((double *)b11, ymm3); _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm4,0)); + _mm256_castpd256_pd128(ymm4)); _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); _mm_storeu_pd((double *)(b11 + cs_b + 2), - _mm256_extractf128_pd(ymm6,0)); + _mm256_castpd256_pd128(ymm6)); m_remainder -= 3; i += 3; } @@ -36495,10 +38755,10 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_2x2(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_2x1(AlphaVal,b11,cs_b) ///implement TRSM/// ////extract a00 @@ -36548,9 +38808,9 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB BLIS_ZTRSM_MUL(ymm5) #endif _mm_storeu_pd((double *)b11, - _mm256_extractf128_pd(ymm3,0)); + _mm256_castpd256_pd128(ymm3)); _mm_storeu_pd((double *)(b11 + cs_b), - _mm256_extractf128_pd(ymm5,0)); + _mm256_castpd256_pd128(ymm5)); m_remainder -= 1; i += 1; } @@ -36595,11 +38855,11 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB ymm5 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); _mm_storeu_pd((double *)(ptr_a10_dup + 2), - _mm256_extractf128_pd(ymm3,0)); + _mm256_castpd256_pd128(ymm3)); _mm_storeu_pd((double *)(ptr_a10_dup + p_lda + 2), - _mm256_extractf128_pd(ymm4,0)); + _mm256_castpd256_pd128(ymm4)); _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + 2), - _mm256_extractf128_pd(ymm5, 0)); + _mm256_castpd256_pd128(ymm5)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } @@ -36691,7 +38951,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_ZTRSM_SMALL_1x3(b11,cs_b,AlphaVal) ///implement TRSM/// @@ -36710,7 +38970,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB _mm256_storeu_pd((double *)b11, ymm3); _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm4,0)); + _mm256_castpd256_pd128(ymm4)); m_remainder -= 3; i += 3; } @@ -36786,7 +39046,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB BLIS_ZTRSM_MUL(ymm3) #endif _mm_storeu_pd((double *)b11, - _mm256_extractf128_pd(ymm3,0)); + _mm256_castpd256_pd128(ymm3)); m_remainder -= 1; i += 1; } @@ -37280,7 +39540,7 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ ymm16 = _mm256_permute_ps(ymm16, 0x44);\ \ - xmm0 = _mm_loadu_ps((float const *)(b11));\ + xmm0 = _mm_loadl_pi(xmm0,(__m64 const *)(b11));\ ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ /*in register transpose * ymm0,ymm1,ymm2 holds @@ -37310,7 +39570,8 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ ymm16 = _mm256_permute_ps(ymm16, 
0x44);\ \ - ymm0 = _mm256_loadu_ps((float const *)(b11));\ + xmm0 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ /*in register transpose * ymm0,ymm1,ymm2 holds * two dcomplex elements of b11 cols*/\ @@ -37339,7 +39600,7 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm16 = _mm256_permute_ps(ymm16, 0x44);\ \ xmm0 = _mm_loadu_ps((float const *)(b11));\ - xmm1 = _mm_loadu_ps((float const *)(b11 + 2));\ + xmm1 = _mm_loadl_pi(xmm1,(__m64 const *)(b11 + 2));\ ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1);\ /*in register transpose @@ -37372,9 +39633,9 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ ymm16 = _mm256_permute_ps(ymm16, 0x44);\ \ - xmm0 = _mm_loadu_ps((float const *)(b11));\ + xmm0 = _mm_loadl_pi(xmm0,(__m64 const *)(b11));\ ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ - xmm1 = _mm_loadu_ps((float const *)(b11 + cs_b * 1));\ + xmm1 = _mm_loadl_pi(xmm1,(__m64 const *)(b11 + cs_b * 1));\ ymm1 = _mm256_insertf128_ps(ymm1, xmm1, 0);\ /*in register transpose * ymm0,ymm1,ymm2 holds @@ -37412,8 +39673,10 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ ymm16 = _mm256_permute_ps(ymm16, 0x44);\ \ - ymm0 = _mm256_loadu_ps((float const *)(b11));\ - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1));\ + xmm0 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + xmm1 = _mm_loadu_ps((float const *)(b11 + cs_b * 1));\ + ymm1 = _mm256_insertf128_ps(ymm1, xmm1, 0);\ /*in register transpose * ymm0,ymm1,ymm2 holds * two dcomplex elements of b11 cols*/\ @@ -37558,6 +39821,132 @@ BLIS_INLINE void ctrsm_small_pack_diag_element }\ } +/** + * Multiplies Alpha with one scomplex + * element of three column. + */ +#define BLIS_PRE_CTRSM_SMALL_3x1(AlphaVal, b11,cs_b){\ + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ + ymm16 = _mm256_permute_ps(ymm16, 0x44);\ + \ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm8 = _mm256_sub_ps(ymm19, ymm8);\ + \ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm10 = _mm256_sub_ps(ymm19, ymm10);\ + \ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm12 = _mm256_sub_ps(ymm19, ymm12);\ + \ +} + +/** + * Multiplies Alpha with two scomplex + * element of three column. 
+ */ +#define BLIS_PRE_CTRSM_SMALL_3x2(AlphaVal, b11,cs_b){\ + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ + ymm16 = _mm256_permute_ps(ymm16, 0x44);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm8 = _mm256_sub_ps(ymm19, ymm8);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm10 = _mm256_sub_ps(ymm19, ymm10);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm12 = _mm256_sub_ps(ymm19, ymm12);\ + \ +} + +/** + * Multiplies Alpha with three scomplex + * element of three column. + */ +#define BLIS_PRE_CTRSM_SMALL_3x3(AlphaVal, b11,cs_b){\ + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ + ymm16 = _mm256_permute_ps(ymm16, 0x44);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11));\ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm8 = _mm256_sub_ps(ymm19, ymm8);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm10 = _mm256_sub_ps(ymm19, ymm10);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b*2 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm12 = _mm256_sub_ps(ymm19, ymm12);\ + \ +} + /** * Multiplies Alpha with four scomplex * element of three column. 
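/* Editor's note -- illustrative sketch only, not part of this patch.
 * The BLIS_PRE_CTRSM_SMALL_3x{1,2,3} macros added above scale scomplex
 * elements of b11 by alpha using a shuffle/mul/fmaddsub idiom.  The
 * self-contained helper below shows the same complex-multiply pattern on
 * four interleaved single-precision complex values; the function name and
 * parameters (cscal4_alpha, b, alpha_r, alpha_i) are hypothetical and exist
 * only for this example.  Compile with AVX2/FMA enabled. */
#include <immintrin.h>

/* b points to 4 complex numbers stored as {re, im} pairs (8 floats).
 * Computes b[i] = alpha * b[i] for i = 0..3. */
static void cscal4_alpha(float *b, float alpha_r, float alpha_i)
{
    /* alpha replicated as (ar, ai, ar, ai, ...) plus a copy with the
     * real/imaginary parts swapped, (ai, ar, ai, ar, ...). */
    __m256 va  = _mm256_setr_ps(alpha_r, alpha_i, alpha_r, alpha_i,
                                alpha_r, alpha_i, alpha_r, alpha_i);
    __m256 vas = _mm256_shuffle_ps(va, va, 0x11);

    __m256 vb  = _mm256_loadu_ps(b);              /* (b0r, b0i, b1r, b1i, ...) */
    __m256 vbr = _mm256_shuffle_ps(vb, vb, 0xA0); /* real parts duplicated     */
    __m256 vbi = _mm256_shuffle_ps(vb, vb, 0xF5); /* imag parts duplicated     */

    /* fmaddsub yields (br*ar - bi*ai, br*ai + bi*ar, ...) == alpha * b */
    __m256 vt  = _mm256_mul_ps(vbi, vas);
    __m256 vr  = _mm256_fmaddsub_ps(vbr, va, vt);

    _mm256_storeu_ps(b, vr);
}
/* In the macros above, the fringe loads use _mm_loadl_pi/_mm_loadu_ps plus
 * _mm256_insertf128_ps rather than a full 256-bit load, so only the 1, 2 or
 * 3 complex elements that actually belong to b11 are touched. */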
@@ -37645,41 +40034,520 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ ymm12 = _mm256_sub_ps(ymm19, ymm12);\ \ - ymm18 = _mm256_shuffle_ps(ymm1, ymm1, 0xA0);\ - ymm19 = _mm256_shuffle_ps(ymm1, ymm1,0xF5);\ - ymm19 = _mm256_mul_ps(ymm19, ymm17);\ - ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ - ymm13 = _mm256_sub_ps(ymm19, ymm13);\ + ymm18 = _mm256_shuffle_ps(ymm1, ymm1, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm1, ymm1,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm13 = _mm256_sub_ps(ymm19, ymm13);\ +} + + +/** + * Performs GEMM operation. + * Four elements of column in ymm0 + * ymm1, ymm2 holds respective broadcasted element. + */ +#define BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter){\ + float *tptr = (float *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_ps((float const *)(b10)); \ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_ps((float const *)(b10)); \ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm10 = _mm256_addsub_ps(ymm10, ymm6);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter){\ + float *tptr = (float *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 
0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm10 = _mm256_addsub_ps(ymm10, ymm6);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter){\ + float *tptr = (float *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm10 = _mm256_addsub_ps(ymm10, ymm6);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter){\ + float *tptr = (float *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_broadcast_ps(( __m128 const *)(b10));\ + ymm0 = _mm256_permute_ps(ymm0, 0x44);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_broadcast_ps(( __m128 const *)(b10 + 
2));\ + ymm0 = _mm256_permute_ps(ymm0, 0x44);\ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm10 = _mm256_addsub_ps(ymm10, ymm6);\ +} + +/** + * Performs GEMM operation. + * Four elements of column in ymm0 + * ymm1 holds respective broadcasted element. + */ +#define BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter){\ + float *tptr = (float *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_ps((float const *)(b10)); \ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_ps((float const *)(b10)); \ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter){\ + float *tptr = (float *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter){\ + float *tptr = (float *)a01;\ + 
if(conjtransa) {\ + ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter){\ + float *tptr = (float *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ +} + +/** + * Performs GEMM operation. + * Eight elements of column in ymm0, ymm1 + * ymm1 holds respective broadcasted element. 
+ */ +#define BLIS_CTRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter){\ + float *tptr = (float *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_ps((float const *)(b10)); \ + ymm1 = _mm256_loadu_ps((float const *)(b10 + 4)); \ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + \ + _mm_prefetch((char*)( b10 + 8*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm9 = _mm256_fmadd_ps(ymm1, ymm2, ymm9);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_ps((float const *)(b10)); \ + ymm1 = _mm256_loadu_ps((float const *)(b10 + 4)); \ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + \ + _mm_prefetch((char*)( b10 + 8*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm9 = _mm256_fmadd_ps(ymm1, ymm2, ymm9);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm5 = _mm256_permute_ps(ymm5, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm9 = _mm256_addsub_ps(ymm9, ymm5);\ } - /** * Performs GEMM operation. - * Four elements of column in ymm0 + * Eight elements of column in ymm0 ymm1 * ymm1, ymm2 holds respective broadcasted element. */ -#define BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter){\ +#define BLIS_CTRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter){\ float *tptr = (float *)a01;\ if(conjtransa) {\ ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ for(k = 0; k< k_iter; k++) \ { \ ymm0 = _mm256_loadu_ps((float const *)(b10)); \ + ymm1 = _mm256_loadu_ps((float const *)(b10 + 4)); \ \ ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ ymm3 = _mm256_mul_ps(ymm3, ymm18);\ \ - _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + _mm_prefetch((char*)( b10 + 8*cs_b), _MM_HINT_T0); \ ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm9 = _mm256_fmadd_ps(ymm1, ymm2, ymm9);\ ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);\ \ ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ \ ymm3 = _mm256_mul_ps(ymm3, ymm18);\ ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm11 = _mm256_fmadd_ps(ymm1, ymm2, ymm11);\ ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);\ \ tptr += 2;\ b10 += cs_b;\ @@ -37689,37 +40557,46 @@ BLIS_INLINE void ctrsm_small_pack_diag_element for(k = 0; k< k_iter; k++) \ { \ ymm0 = _mm256_loadu_ps((float const *)(b10)); \ + ymm1 = _mm256_loadu_ps((float const *)(b10 + 4)); \ \ ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ \ - _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + _mm_prefetch((char*)( b10 + 8*cs_b), _MM_HINT_T0); \ ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm9 = _mm256_fmadd_ps(ymm1, ymm2, ymm9);\ ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);\ \ ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ \ ymm10 = 
_mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm11 = _mm256_fmadd_ps(ymm1, ymm2, ymm11);\ ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);\ \ tptr += 2;\ b10 += cs_b;\ }\ }\ ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm5 = _mm256_permute_ps(ymm5, 0xb1);\ ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + ymm7 = _mm256_permute_ps(ymm7, 0xb1);\ \ ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm9 = _mm256_addsub_ps(ymm9, ymm5);\ ymm10 = _mm256_addsub_ps(ymm10, ymm6);\ + ymm11 = _mm256_addsub_ps(ymm11, ymm7);\ } /** * Performs GEMM operation. * Four elements of column in ymm0 - * ymm1 holds respective broadcasted element. + * ymm1, ymm2 holds respective broadcasted element. */ -#define BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter){\ +#define BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) {\ float *tptr = (float *)a01;\ if(conjtransa) {\ ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ @@ -37735,6 +40612,20 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2 + 1);\ + \ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);\ + ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);\ + \ tptr += 2;\ b10 += cs_b;\ }\ @@ -37744,45 +40635,70 @@ BLIS_INLINE void ctrsm_small_pack_diag_element { \ ymm0 = _mm256_loadu_ps((float const *)(b10)); \ \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ \ - _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2 + 1);\ + \ + ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);\ + ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);\ + \ tptr += 2;\ b10 += cs_b;\ }\ }\ ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + ymm14 = _mm256_permute_ps(ymm14, 0xb1);\ \ ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm10 = _mm256_addsub_ps(ymm10, ymm6);\ + ymm12 = _mm256_addsub_ps(ymm12, ymm14);\ } -/** - * Performs GEMM operation. - * Eight elements of column in ymm0, ymm1 - * ymm1 holds respective broadcasted element. 
- */ -#define BLIS_CTRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter){\ +#define BLIS_CTRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) {\ float *tptr = (float *)a01;\ if(conjtransa) {\ ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ for(k = 0; k< k_iter; k++) \ { \ - ymm0 = _mm256_loadu_ps((float const *)(b10)); \ - ymm1 = _mm256_loadu_ps((float const *)(b10 + 4)); \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ \ ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ ymm3 = _mm256_mul_ps(ymm3, ymm18);\ \ - _mm_prefetch((char*)( b10 + 8*cs_b), _MM_HINT_T0); \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ - ymm9 = _mm256_fmadd_ps(ymm1, ymm2, ymm9);\ ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ - ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2 + 1);\ + \ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);\ + ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);\ \ tptr += 2;\ b10 += cs_b;\ @@ -37791,61 +40707,73 @@ BLIS_INLINE void ctrsm_small_pack_diag_element else {\ for(k = 0; k< k_iter; k++) \ { \ - ymm0 = _mm256_loadu_ps((float const *)(b10)); \ - ymm1 = _mm256_loadu_ps((float const *)(b10 + 4)); \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ \ - _mm_prefetch((char*)( b10 + 8*cs_b), _MM_HINT_T0); \ ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ - ymm9 = _mm256_fmadd_ps(ymm1, ymm2, ymm9);\ ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ - ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2 + 1);\ + \ + ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);\ + ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);\ \ tptr += 2;\ b10 += cs_b;\ }\ }\ ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ - ymm5 = _mm256_permute_ps(ymm5, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + ymm14 = _mm256_permute_ps(ymm14, 0xb1);\ \ ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ - ymm9 = _mm256_addsub_ps(ymm9, ymm5);\ + ymm10 = _mm256_addsub_ps(ymm10, ymm6);\ + ymm12 = _mm256_addsub_ps(ymm12, ymm14);\ } -/** - * Performs GEMM operation. - * Eight elements of column in ymm0 ymm1 - * ymm1, ymm2 holds respective broadcasted element. 
- */ -#define BLIS_CTRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter){\ +#define BLIS_CTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) {\ float *tptr = (float *)a01;\ if(conjtransa) {\ ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ for(k = 0; k< k_iter; k++) \ { \ - ymm0 = _mm256_loadu_ps((float const *)(b10)); \ - ymm1 = _mm256_loadu_ps((float const *)(b10 + 4)); \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ \ ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ ymm3 = _mm256_mul_ps(ymm3, ymm18);\ \ - _mm_prefetch((char*)( b10 + 8*cs_b), _MM_HINT_T0); \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ - ymm9 = _mm256_fmadd_ps(ymm1, ymm2, ymm9);\ ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ - ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);\ \ ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ \ ymm3 = _mm256_mul_ps(ymm3, ymm18);\ ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ - ymm11 = _mm256_fmadd_ps(ymm1, ymm2, ymm11);\ ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ - ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2 + 1);\ + \ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);\ + ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);\ \ tptr += 2;\ b10 += cs_b;\ @@ -37854,54 +40782,50 @@ BLIS_INLINE void ctrsm_small_pack_diag_element else {\ for(k = 0; k< k_iter; k++) \ { \ - ymm0 = _mm256_loadu_ps((float const *)(b10)); \ - ymm1 = _mm256_loadu_ps((float const *)(b10 + 4)); \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ \ - _mm_prefetch((char*)( b10 + 8*cs_b), _MM_HINT_T0); \ ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ - ymm9 = _mm256_fmadd_ps(ymm1, ymm2, ymm9);\ ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ - ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);\ \ ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ \ ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ - ymm11 = _mm256_fmadd_ps(ymm1, ymm2, ymm11);\ ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ - ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2 + 1);\ + \ + ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);\ + ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);\ \ tptr += 2;\ b10 += cs_b;\ }\ }\ ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ - ymm5 = _mm256_permute_ps(ymm5, 0xb1);\ ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ - ymm7 = _mm256_permute_ps(ymm7, 0xb1);\ + ymm14 = _mm256_permute_ps(ymm14, 0xb1);\ \ ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ - ymm9 = _mm256_addsub_ps(ymm9, ymm5);\ ymm10 = _mm256_addsub_ps(ymm10, ymm6);\ - ymm11 = _mm256_addsub_ps(ymm11, ymm7);\ + ymm12 = _mm256_addsub_ps(ymm12, ymm14);\ } -/** - * Performs GEMM operation. - * Four elements of column in ymm0 - * ymm1, ymm2 holds respective broadcasted element. 
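/* ---- Editor's illustrative sketch (not part of the patch) -------------------
 * For the m = 3, 2 and 1 fringe cases the macros above assemble ymm0 from 16-
 * and 8-byte pieces instead of one 32-byte loadu, so that only the bytes
 * belonging to the remaining B elements are read: _mm_loadu_ps brings in two
 * complex values, _mm_loadl_pi brings in one, and _mm256_insertf128_ps places
 * each half into the ymm register. A standalone helper for the 3-element case,
 * with a hypothetical name and under those assumptions:
 */
#include <immintrin.h>

static inline __m256 load_3_scomplex(const float *p)  /* p -> {r0,i0,r1,i1,r2,i2} */
{
    __m128 lo = _mm_loadu_ps(p);                       /* elements 0 and 1 (16 B) */
    __m128 hi = _mm_setzero_ps();
    hi = _mm_loadl_pi(hi, (const __m64 *)(p + 4));     /* element 2 only (8 B)    */
    __m256 v = _mm256_castps128_ps256(lo);
    return _mm256_insertf128_ps(v, hi, 1);             /* upper lane: {r2,i2,0,0} */
}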
- */ -#define BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) {\ +#define BLIS_CTRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) {\ float *tptr = (float *)a01;\ if(conjtransa) {\ ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ for(k = 0; k< k_iter; k++) \ { \ - ymm0 = _mm256_loadu_ps((float const *)(b10)); \ - \ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ ymm3 = _mm256_mul_ps(ymm3, ymm18);\ @@ -37931,8 +40855,9 @@ BLIS_INLINE void ctrsm_small_pack_diag_element else {\ for(k = 0; k< k_iter; k++) \ { \ - ymm0 = _mm256_loadu_ps((float const *)(b10)); \ - \ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ @@ -37965,7 +40890,6 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm12 = _mm256_addsub_ps(ymm12, ymm14);\ } - /** * Performs GEMM operation. * Eight elements of column in ymm0, ymm1 @@ -38216,37 +41140,413 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm9 = _mm256_addsub_ps(ymm9, ymm6);\ } -/** - * Performs GEMM operation. - * Eight elements of column in ymm0, ymm1 - * ymm1, ymm2 holds respective broadcasted element. - */ -#define BLIS_CTRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) {\ +/** + * Performs GEMM operation. + * Eight elements of column in ymm0, ymm1 + * ymm1, ymm2 holds respective broadcasted element. + */ +#define BLIS_CTRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_ps((float const *)(a10)); \ + ymm1 = _mm256_loadu_ps((float const *)(a10 + 4)); \ + ymm0 = _mm256_mul_ps(ymm0, ymm18);\ + ymm1 = _mm256_mul_ps(ymm1, ymm18);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_ps(ymm1, ymm2, ymm12);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm13 = _mm256_fmadd_ps(ymm1, ymm2, ymm13);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_ps((float const *)(a10)); \ + ymm1 = _mm256_loadu_ps((float const *)(a10 + 4)); \ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_ps(ymm1, ymm2, ymm12);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm13 = _mm256_fmadd_ps(ymm1, ymm2, ymm13);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm5 = _mm256_permute_ps(ymm5, 0xb1);\ 
+ ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + ymm7 = _mm256_permute_ps(ymm7, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm12 = _mm256_addsub_ps(ymm12, ymm5);\ + ymm9 = _mm256_addsub_ps(ymm9, ymm6);\ + ymm13 = _mm256_addsub_ps(ymm13, ymm7);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_ps((float const *)(a10)); \ + ymm0 = _mm256_mul_ps(ymm0, ymm18);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm11 = _mm256_fmadd_ps(ymm0, ymm3, ymm11);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_ps((float const *)(a10)); \ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm11 = _mm256_fmadd_ps(ymm0, ymm3, ymm11);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + ymm11 = _mm256_permute_ps(ymm11, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm9 = _mm256_addsub_ps(ymm9, ymm6);\ + ymm10 = _mm256_addsub_ps(ymm10, ymm11);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_2mx3n(a10,b01,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm0 = _mm256_mul_ps(ymm0, ymm18);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm11 = _mm256_fmadd_ps(ymm0, ymm3, ymm11);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = 
_mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm11 = _mm256_fmadd_ps(ymm0, ymm3, ymm11);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + ymm11 = _mm256_permute_ps(ymm11, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm9 = _mm256_addsub_ps(ymm9, ymm6);\ + ymm10 = _mm256_addsub_ps(ymm10, ymm11);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_3mx3n(a10,b01,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(a10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm0 = _mm256_mul_ps(ymm0, ymm18);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm11 = _mm256_fmadd_ps(ymm0, ymm3, ymm11);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(a10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm11 = _mm256_fmadd_ps(ymm0, ymm3, ymm11);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + ymm11 = _mm256_permute_ps(ymm11, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm9 = _mm256_addsub_ps(ymm9, ymm6);\ + ymm10 = _mm256_addsub_ps(ymm10, ymm11);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_3mx2n(a10,b01,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(a10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm0 = _mm256_mul_ps(ymm0, ymm18);\ + \ + ymm2 = 
_mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(a10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm9 = _mm256_addsub_ps(ymm9, ymm6);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_3mx1n(a10,b01,cs_b,p_lda,k_iter) {\ float *tptr = (float *)b01;\ if(conjtransa) {\ ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ for(k = 0; k< k_iter; k++) \ { \ - ymm0 = _mm256_loadu_ps((float const *)(a10)); \ - ymm1 = _mm256_loadu_ps((float const *)(a10 + 4)); \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(a10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm0 = _mm256_mul_ps(ymm0, ymm18);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(a10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_2mx2n(a10,b01,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ ymm0 = _mm256_mul_ps(ymm0, ymm18);\ - ymm1 = _mm256_mul_ps(ymm1, ymm18);\ \ ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ \ ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ - ymm12 = _mm256_fmadd_ps(ymm1, ymm2, ymm12);\ ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ - ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);\ \ ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ \ ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ - ymm13 = _mm256_fmadd_ps(ymm1, ymm2, ymm13);\ ymm6 
= _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ - ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);\ \ tptr += 2;\ a10 += p_lda;\ @@ -38255,47 +41555,81 @@ BLIS_INLINE void ctrsm_small_pack_diag_element else {\ for(k = 0; k< k_iter; k++) \ { \ - ymm0 = _mm256_loadu_ps((float const *)(a10)); \ - ymm1 = _mm256_loadu_ps((float const *)(a10 + 4)); \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ \ ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ \ ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ - ymm12 = _mm256_fmadd_ps(ymm1, ymm2, ymm12);\ ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ - ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);\ \ ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ \ ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ - ymm13 = _mm256_fmadd_ps(ymm1, ymm2, ymm13);\ ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ - ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);\ \ tptr += 2;\ a10 += p_lda;\ }\ }\ ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ - ymm5 = _mm256_permute_ps(ymm5, 0xb1);\ ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ - ymm7 = _mm256_permute_ps(ymm7, 0xb1);\ \ ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ - ymm12 = _mm256_addsub_ps(ymm12, ymm5);\ ymm9 = _mm256_addsub_ps(ymm9, ymm6);\ - ymm13 = _mm256_addsub_ps(ymm13, ymm7);\ } -#define BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) {\ +#define BLIS_CTRSM_SMALL_GEMM_2mx1n(a10,b01,cs_b,p_lda,k_iter) {\ float *tptr = (float *)b01;\ if(conjtransa) {\ ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ for(k = 0; k< k_iter; k++) \ { \ - ymm0 = _mm256_loadu_ps((float const *)(a10)); \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm0 = _mm256_mul_ps(ymm0, ymm18);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_1mx3n(a10,b01,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ ymm0 = _mm256_mul_ps(ymm0, ymm18);\ \ ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ @@ -38323,7 +41657,8 @@ BLIS_INLINE void ctrsm_small_pack_diag_element else {\ for(k = 0; k< k_iter; k++) \ { \ - ymm0 = _mm256_loadu_ps((float const *)(a10)); \ + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ \ ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ @@ -38356,6 +41691,102 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm10 = _mm256_addsub_ps(ymm10, ymm11);\ } +#define BLIS_CTRSM_SMALL_GEMM_1mx2n(a10,b01,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)b01;\ + if(conjtransa) {\ + ymm18 = 
_mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + ymm0 = _mm256_mul_ps(ymm0, ymm18);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm9 = _mm256_addsub_ps(ymm9, ymm6);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_1mx1n(a10,b01,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + ymm0 = _mm256_mul_ps(ymm0, ymm18);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ +} + /** * Performs GEMM operation. 
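/* ---- Editor's illustrative sketch (not part of the patch) -------------------
 * conjtransa is handled with a sign mask rather than a separate arithmetic
 * path. In the *mx?n macros above, ymm18 = {1,-1,1,-1,...} multiplies the
 * loaded vector of A elements, flipping every imaginary part and thereby
 * conjugating all four complex values at once; in the *nx?m macros the mask is
 * all -1 and is applied only to the broadcast imaginary part, conjugating the
 * single broadcast element. A minimal version of the first form, with a
 * hypothetical helper name:
 */
#include <immintrin.h>

static inline __m256 conj4_ps(__m256 a)    /* a = {r0,i0,r1,i1,r2,i2,r3,i3} */
{
    const __m256 sign = _mm256_setr_ps(1.0f, -1.0f, 1.0f, -1.0f,
                                       1.0f, -1.0f, 1.0f, -1.0f);
    return _mm256_mul_ps(a, sign);         /* {r0,-i0, r1,-i1, ...}         */
}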
* Eight elements of column in ymm0, ymm1 @@ -38452,9 +41883,9 @@ BLIS_INLINE void ctrsm_small_pack_diag_element #define BLIS_CTRSM_SMALL_NREG_TRANSPOSE_1x4(b11,cs_b,AlphaVal){\ ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ ymm16 = _mm256_permute_ps(ymm16, 0x44);\ -\ + \ ymm0 = _mm256_loadu_ps((float const *)(b11));\ - ymm3 = _mm256_broadcast_ps((__m128 const *)&ones);\ + ymm3 = _mm256_broadcast_ps((__m128 const *)&ones_a);\ ymm3 = _mm256_permute_ps(ymm3, 0x44);\ /*in register transpose * ymm0,ymm1,ymm2 holds @@ -38504,7 +41935,7 @@ BLIS_INLINE void ctrsm_small_pack_diag_element \ ymm0 = _mm256_loadu_ps((float const *)(b11));\ ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1));\ - ymm3 = _mm256_broadcast_ps((__m128 const *)&ones);\ + ymm3 = _mm256_broadcast_ps((__m128 const *)&ones_a);\ ymm3 = _mm256_permute_ps(ymm3, 0x44);\ /*in register transpose * ymm0,ymm1,ymm2 holds @@ -38561,7 +41992,7 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm0 = _mm256_loadu_ps((float const *)(b11));\ ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1));\ ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2));\ - ymm3 = _mm256_broadcast_ps((__m128 const *)&ones);\ + ymm3 = _mm256_broadcast_ps((__m128 const *)&ones_a);\ ymm3 = _mm256_permute_ps(ymm3, 0x44);\ /*in register transpose * ymm0,ymm1,ymm2 holds @@ -38624,7 +42055,7 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm0 = _mm256_loadu_ps((float const *)(b11));\ ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1));\ ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2));\ - ymm3 = _mm256_broadcast_ps((__m128 const *)&ones);\ + ymm3 = _mm256_broadcast_ps((__m128 const *)&ones_a);\ ymm3 = _mm256_permute_ps(ymm3, 0x44);\ /*in register transpose * ymm0,ymm1,ymm2 holds @@ -38781,7 +42212,6 @@ BLIS_INLINE void ctrsm_small_pack_diag_element _mm256_storeu_ps((float *)(b11 + cs_b * 2 + 4), ymm2);\ } - BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ( obj_t* AlphaObj, @@ -38816,13 +42246,17 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB dim_t i, j, k; //loop variables dim_t k_iter; //number of times GEMM to be performed - scomplex AlphaVal = *(scomplex *)AlphaObj->buffer; //value of alpha - scomplex *L = a->buffer; //pointer to matrix A - scomplex *B = b->buffer; //pointer to matrix B + scomplex AlphaVal[2]; + AlphaVal[0] = *(scomplex *)AlphaObj->buffer; //value of alpha + AlphaVal[1] = *(scomplex *)AlphaObj->buffer; //value of alpha + + scomplex *L = bli_obj_buffer_at_off(a); //pointer to matrix A + scomplex *B = bli_obj_buffer_at_off(b); //pointer to matrix B scomplex *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM - scomplex ones = {1.0, 1.0}; + float ones = 1.0; + float ones_a[4] = {1.0, 1.0,1.0,1.0}; bool is_unitdiag = bli_obj_has_unit_diag(a); //scratch registers @@ -38833,8 +42267,16 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB __m256 ymm16, ymm17, ymm18, ymm19; __m128 xmm0, xmm1, xmm2, xmm3, xmm4; - - gint_t required_packing_A = 1; + __m128 xmm5; + + xmm0 = _mm_setzero_ps(); + xmm1 = _mm_setzero_ps(); + xmm2 = _mm_setzero_ps(); + xmm3 = _mm_setzero_ps(); + xmm4 = _mm_setzero_ps(); + xmm5 = _mm_setzero_ps(); + + gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; scomplex *D_A_pack = NULL; scomplex d11_pack[d_mr] __attribute__((aligned(64))); @@ -38863,8 +42305,8 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); if(NULL==D_A_pack) return BLIS_NULL_POINTER; } - - /* + + /* Performs solving TRSM for 4 colmns at a time from 0 to m/4 in steps of d_mr a. 
Load, transpose, Pack A (a10 block), the size of packing 4x3 to 4x (m-4) First there will be no GEMM and no packing of a10 because it is only TRSM @@ -38949,14 +42391,20 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB */ ////extract a00 ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -38973,7 +42421,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -38988,7 +42439,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39003,7 +42457,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_set_ps((a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39018,7 +42475,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_set_ps((a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39033,7 +42493,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39048,7 +42511,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) 
(a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39063,7 +42529,11 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); + ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm9) @@ -39074,7 +42544,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39089,7 +42562,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39104,7 +42580,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_set_ps((a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39119,7 +42598,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_set_ps((a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39134,7 +42616,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39149,7 +42634,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); 
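/* ---- Editor's illustrative sketch (not part of the patch) -------------------
 * The block comment above describes the overall flow of
 * bli_ctrsm_small_AutXB_AlXB: walk over the m dimension in steps of d_mr,
 * update the current block of B (b11) with a small GEMM against the
 * already-solved rows (a10 * b01), then solve the diagonal block (a11) by
 * substitution. A scalar, real-valued sketch of that structure only; block
 * size, names and the one-column-at-a-time loop are illustrative, not the
 * kernel's actual interface (the kernel works on several columns at once):
 */
#include <stddef.h>

/* Solve L * X = alpha * B in place; L is lower-triangular m x m (column-major,
 * leading dimension lda), B/X is m x n (leading dimension ldb), block size bs. */
static void trsm_lower_blocked(size_t m, size_t n, double alpha,
                               const double *L, size_t lda,
                               double *B, size_t ldb, size_t bs)
{
    for (size_t i = 0; i < m; i += bs)
    {
        size_t ib = (m - i < bs) ? (m - i) : bs;
        for (size_t j = 0; j < n; j++)
        {
            /* GEMM-like update: b11 -= a10 * b01 (empty when i == 0, matching
             * "first there will be no GEMM" in the comment above). */
            for (size_t r = 0; r < ib; r++)
            {
                double acc = alpha * B[(i + r) + j * ldb];
                for (size_t k = 0; k < i; k++)
                    acc -= L[(i + r) + k * lda] * B[k + j * ldb];
                B[(i + r) + j * ldb] = acc;
            }
            /* Small TRSM on the ib x ib diagonal block (forward substitution). */
            for (size_t r = 0; r < ib; r++)
            {
                double acc = B[(i + r) + j * ldb];
                for (size_t k = 0; k < r; k++)
                    acc -= L[(i + r) + (i + k) * lda] * B[(i + k) + j * ldb];
                B[(i + r) + j * ldb] = acc / L[(i + r) + (i + r) * lda];
            }
        }
    }
}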
if(conjtransa) { @@ -39165,7 +42653,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -39177,7 +42668,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39192,7 +42686,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_set_ps((a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39207,7 +42704,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_set_ps((a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39222,7 +42722,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39237,7 +42740,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39253,7 +42759,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_set_ps((d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm11) @@ -39264,7 +42773,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_set_ps((a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + 
cs_a*4)->imag,(a11 + cs_a*4)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39279,7 +42791,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_set_ps((a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39294,7 +42809,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39309,7 +42827,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39325,7 +42846,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 4)); + ymm1 = _mm256_set_ps((d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm12) @@ -39334,7 +42858,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB #endif a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_set_ps((a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39349,7 +42876,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39364,7 +42894,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39380,7 +42913,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 5)); + ymm1 = _mm256_set_ps((d11_pack + 
5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm13) @@ -39390,7 +42926,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39405,7 +42944,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39421,7 +42963,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 6)); + ymm1 = _mm256_set_ps((d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm14) @@ -39431,7 +42976,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39447,7 +42995,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 7)); + ymm1 = _mm256_set_ps((d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm15) @@ -39476,8 +43027,8 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB BLIS_CTRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) float zero = 0.0; - ymm16 = _mm256_broadcast_ss(&AlphaVal.real); - ymm17 = _mm256_broadcast_ss(&AlphaVal.imag); + ymm16 = _mm256_broadcast_ss(&AlphaVal[0].real); + ymm17 = _mm256_broadcast_ss(&AlphaVal[0].imag); ymm2 = _mm256_broadcast_ss(&zero); ymm3 = _mm256_broadcast_ss(&zero); ymm6 = _mm256_broadcast_ss(&zero); @@ -39528,8 +43079,8 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB BLIS_CTRSM_SMALL_GEMM_8mx1n(a10,b01,cs_b,p_lda,k_iter) float zero = 0.0; - ymm16 = _mm256_broadcast_ss(&AlphaVal.real); - ymm17 = _mm256_broadcast_ss(&AlphaVal.imag); + ymm16 = _mm256_broadcast_ss(&AlphaVal[0].real); + ymm17 = _mm256_broadcast_ss(&AlphaVal[0].imag); ymm2 = _mm256_broadcast_ss(&zero); ymm3 = _mm256_broadcast_ss(&zero); ymm6 = _mm256_broadcast_ss(&zero); @@ -39597,14 +43148,20 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_permute2f128_ps(ymm18,ymm19,0x31); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, 
-1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39621,7 +43178,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39636,7 +43196,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39651,7 +43214,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_set_ps((a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39666,7 +43232,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_set_ps((a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39681,7 +43250,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39696,7 +43268,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39711,7 +43286,10 @@ 
BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm9) @@ -39722,7 +43300,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39737,7 +43318,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39752,7 +43336,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_set_ps((a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39767,7 +43354,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_set_ps((a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39782,7 +43372,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39797,7 +43390,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39813,7 +43409,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + 
(d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -39825,7 +43424,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39840,7 +43442,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_set_ps((a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39855,7 +43460,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_set_ps((a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39870,7 +43478,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39885,7 +43496,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39901,7 +43515,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_set_ps((d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm11) @@ -39912,7 +43529,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_set_ps((a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39927,7 +43547,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); 
+ ymm2 = _mm256_set_ps((a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39942,7 +43565,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39957,7 +43583,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39973,7 +43602,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 4)); + ymm1 = _mm256_set_ps((d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm12) @@ -39982,7 +43614,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB #endif a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_set_ps((a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -39997,7 +43632,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -40012,7 +43650,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -40028,7 +43669,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 5)); + ymm1 = _mm256_set_ps((d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm13) @@ -40038,7 +43682,10 @@ BLIS_INLINE err_t 
bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -40053,7 +43700,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -40069,7 +43719,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 6)); + ymm1 = _mm256_set_ps((d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm14) @@ -40079,7 +43732,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -40095,7 +43751,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 7)); + ymm1 = _mm256_set_ps((d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm15) @@ -40211,14 +43870,20 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB BLIS_CTRSM_SMALL_NREG_TRANSPOSE_3x4(b11,cs_b,AlphaVal) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -40234,7 +43899,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -40248,7 +43916,10 @@ 
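/*
 * Note on the solve structure visible in these hunks (unchanged by the patch,
 * which only alters how the scalars reach the ymm registers): each step
 * broadcasts one d11_pack entry, apparently the packed diagonal of A11, into
 * ymm1 and applies BLIS_CTRSM_DIV or BLIS_CTRSM_MUL to one accumulator
 * (ymm8..ymm15), selected by BLIS_ENABLE_TRSM_PREINVERSION; the following
 * broadcasts of a11 + cs_a*j into ymm2 feed the _mm256_fmaddsub_ps /
 * _mm256_sub_ps updates that fold the just-solved row into the remaining
 * accumulators before a11 is advanced by rs_a.
 */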
BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -40263,7 +43934,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm9) @@ -40274,7 +43948,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -40288,7 +43965,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -40304,7 +43984,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -40316,7 +43999,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -40332,7 +44018,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_set_ps((d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm11) @@ -40360,7 +44049,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ///GEMM code begins/// BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) BLIS_CTRSM_SMALL_NREG_TRANSPOSE_2x4(b11,cs_b,AlphaVal) - } + } else if(1 == n_rem) { ///GEMM code begins/// @@ -40370,14 +44059,20 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, 
+ (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -40393,7 +44088,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -40407,7 +44105,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -40422,7 +44123,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm9) @@ -40433,7 +44137,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -40447,7 +44154,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -40463,7 +44173,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -40475,7 +44188,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + 
cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -40491,7 +44207,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_set_ps((d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm11) @@ -40541,8 +44260,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB (float const *)(a10 + cs_a)); ymm2 = _mm256_loadu_ps( (float const *)(a10 + cs_a * 2)); - ymm3 = _mm256_loadu_ps( - (float const *)(a10 + cs_a * 3)); + ymm3 = _mm256_broadcast_ss((float const *)&ones); ymm4 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); ymm5 = _mm256_shuffle_ps(ymm2, ymm3, 0x44); @@ -40588,14 +44306,23 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB k_iter = i; BLIS_SET_S_YMM_REG_ZEROS - BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_3mx3n(a10,b01,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 1); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + ymm1 = _mm256_insertf128_ps(ymm1, xmm0, 0); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + cs_b + 2)); + ymm1 = _mm256_insertf128_ps(ymm1, xmm0, 1); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2)); + ymm2 = _mm256_insertf128_ps(ymm2, xmm0, 0); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + cs_b*2 + 2)); + ymm2 = _mm256_insertf128_ps(ymm2, xmm0, 1); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); @@ -40658,7 +44385,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB if(2 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_3mx2n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) @@ -40675,7 +44402,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB else if(1 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_3mx1n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) @@ -40749,15 +44476,17 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB k_iter = i; BLIS_SET_S_YMM_REG_ZEROS - BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_2mx3n(a10,b01,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2)); - + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + ymm1 = _mm256_insertf128_ps(ymm1, xmm0, 0); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2)); + ymm2 = _mm256_insertf128_ps(ymm2, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); @@ -40812,7 +44541,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB if(2 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + 
BLIS_CTRSM_SMALL_GEMM_2mx2n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) @@ -40829,7 +44558,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB else if(1 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_2mx1n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) @@ -40901,14 +44630,17 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB k_iter = i; BLIS_SET_S_YMM_REG_ZEROS - BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_1mx3n(a10,b01,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2)); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + cs_b)); + ymm1 = _mm256_insertf128_ps(ymm1, xmm0, 0); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + cs_b*2)); + ymm2 = _mm256_insertf128_ps(ymm2, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); @@ -40964,7 +44696,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB if(2 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_1mx2n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_1M_2N(AlphaVal,b11, @@ -40984,7 +44716,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB else if(1 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_1mx1n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_1M_1N(AlphaVal,b11, @@ -41046,13 +44778,17 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB dim_t i, j, k; //loop variables dim_t k_iter; //number of times GEMM to be performed - scomplex AlphaVal = *(scomplex *)AlphaObj->buffer; //value of alpha - scomplex *L = a->buffer; //pointer to matrix A - scomplex *B = b->buffer; //pointer to matrix B + scomplex AlphaVal[2]; + AlphaVal[0] = *(scomplex *)AlphaObj->buffer; //value of alpha + AlphaVal[1] = *(scomplex *)AlphaObj->buffer; //value of alpha + + scomplex *L = bli_obj_buffer_at_off(a); //pointer to matrix A + scomplex *B = bli_obj_buffer_at_off(b); //pointer to matrix B scomplex *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM - scomplex ones = {1.0, 1.0}; + float ones = 1.0; + float ones_a[4] = {1.0, 1.0,1.0,1.0}; bool is_unitdiag = bli_obj_has_unit_diag(a); //scratch registers @@ -41063,6 +44799,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB __m256 ymm16, ymm17, ymm18, ymm19; __m128 xmm0, xmm1, xmm2, xmm3, xmm4; + __m128 xmm5; + + xmm0 = _mm_setzero_ps(); + xmm1 = _mm_setzero_ps(); + xmm2 = _mm_setzero_ps(); + xmm3 = _mm_setzero_ps(); + xmm4 = _mm_setzero_ps(); + xmm5 = _mm_setzero_ps(); gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; @@ -41181,14 +44925,24 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB */ ////extract a00 ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 7)); + ymm1 = _mm256_set_ps((d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm15) #else BLIS_CTRSM_MUL(ymm15) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6 + 7*rs_a)); + ymm2 = 
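/*
 * Notes on the m-remainder and declaration changes above and in this hunk,
 * followed by one illustrative sketch (helper names here are invented, not
 * from the patch):
 *
 * - The m = 3, 2 and 1 remainder paths of bli_ctrsm_small_AutXB_AlXB now call
 *   the GEMM macros matching the row count (BLIS_CTRSM_SMALL_GEMM_3mx*n,
 *   _2mx*n, _1mx*n) instead of the 4m variants, and the fourth a10 column in
 *   the 3-row pack is replaced by a broadcast of `ones`.
 * - b11 is loaded with _mm_loadu_ps / _mm_loadl_pi plus _mm256_insertf128_ps
 *   rather than full 256-bit _mm256_loadu_ps, so only the rows that exist are
 *   read.  Separately, the xmm registers at the top of
 *   bli_ctrsm_small_AltXB_AuXB are now zero-initialised, presumably so that
 *   partially filled registers never carry undefined lanes.
 * - Alpha is copied into a two-entry AlphaVal array and the A/B base pointers
 *   are taken via bli_obj_buffer_at_off() instead of reading obj->buffer
 *   directly.
 */
#include <immintrin.h>

/* Sketch of the partial b11 load, assuming interleaved (real, imag) float
   pairs; m_complex is the number of remaining rows, 1..3. */
static inline __m256 load_b11_partial_sketch(const float *b, int m_complex)
{
    __m128 lo = _mm_setzero_ps();
    __m128 hi = _mm_setzero_ps();
    if (m_complex >= 2)
        lo = _mm_loadu_ps(b);                           /* elements 0 and 1 */
    else
        lo = _mm_loadl_pi(lo, (const __m64 *)b);        /* element 0 only   */
    if (m_complex == 3)
        hi = _mm_loadl_pi(hi, (const __m64 *)(b + 4));  /* element 2        */
    return _mm256_insertf128_ps(_mm256_castps128_ps256(lo), hi, 1);
}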
_mm256_set_ps((a11 + cs_a*6 + 7*rs_a)->imag, + (a11 + cs_a*6 + 7*rs_a)->real, + (a11 + cs_a*6 + 7*rs_a)->imag, + (a11 + cs_a*6 + 7*rs_a)->real, + (a11 + cs_a*6 + 7*rs_a)->imag, + (a11 + cs_a*6 + 7*rs_a)->real, + (a11 + cs_a*6 + 7*rs_a)->imag, + (a11 + cs_a*6 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41204,7 +44958,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*5 + 7*rs_a)->imag, + (a11 + cs_a*5 + 7*rs_a)->real, + (a11 + cs_a*5 + 7*rs_a)->imag, + (a11 + cs_a*5 + 7*rs_a)->real, + (a11 + cs_a*5 + 7*rs_a)->imag, + (a11 + cs_a*5 + 7*rs_a)->real, + (a11 + cs_a*5 + 7*rs_a)->imag, + (a11 + cs_a*5 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41219,7 +44980,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*4 + 7*rs_a)->imag, + (a11 + cs_a*4 + 7*rs_a)->real, + (a11 + cs_a*4 + 7*rs_a)->imag, + (a11 + cs_a*4 + 7*rs_a)->real, + (a11 + cs_a*4 + 7*rs_a)->imag, + (a11 + cs_a*4 + 7*rs_a)->real, + (a11 + cs_a*4 + 7*rs_a)->imag, + (a11 + cs_a*4 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41234,7 +45002,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*3 + 7*rs_a)->imag, + (a11 + cs_a*3 + 7*rs_a)->real, + (a11 + cs_a*3 + 7*rs_a)->imag, + (a11 + cs_a*3 + 7*rs_a)->real, + (a11 + cs_a*3 + 7*rs_a)->imag, + (a11 + cs_a*3 + 7*rs_a)->real, + (a11 + cs_a*3 + 7*rs_a)->imag, + (a11 + cs_a*3 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41249,7 +45024,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 7*rs_a)->imag, + (a11 + cs_a*2 + 7*rs_a)->real, + (a11 + cs_a*2 + 7*rs_a)->imag, + (a11 + cs_a*2 + 7*rs_a)->real, + (a11 + cs_a*2 + 7*rs_a)->imag, + (a11 + cs_a*2 + 7*rs_a)->real, + (a11 + cs_a*2 + 7*rs_a)->imag, + (a11 + cs_a*2 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41264,7 +45046,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 7*rs_a)->imag, + (a11 + cs_a*1 + 7*rs_a)->real, + (a11 + cs_a*1 + 7*rs_a)->imag, + (a11 + cs_a*1 + 7*rs_a)->real, + (a11 + cs_a*1 + 7*rs_a)->imag, + (a11 + cs_a*1 + 7*rs_a)->real, + (a11 + cs_a*1 + 7*rs_a)->imag, + (a11 + cs_a*1 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41279,7 +45068,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 7*rs_a)->imag, + (a11 + cs_a*0 + 7*rs_a)->real, + (a11 + cs_a*0 + 7*rs_a)->imag, + (a11 + cs_a*0 + 7*rs_a)->real, + (a11 + cs_a*0 
+ 7*rs_a)->imag, + (a11 + cs_a*0 + 7*rs_a)->real, + (a11 + cs_a*0 + 7*rs_a)->imag, + (a11 + cs_a*0 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41294,7 +45090,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 6)); + ymm1 = _mm256_set_ps((d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm14) @@ -41302,7 +45101,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB BLIS_CTRSM_MUL(ymm14) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*5 + 6*rs_a)->imag, + (a11 + cs_a*5 + 6*rs_a)->real, + (a11 + cs_a*5 + 6*rs_a)->imag, + (a11 + cs_a*5 + 6*rs_a)->real, + (a11 + cs_a*5 + 6*rs_a)->imag, + (a11 + cs_a*5 + 6*rs_a)->real, + (a11 + cs_a*5 + 6*rs_a)->imag, + (a11 + cs_a*5 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41317,7 +45123,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*4 + 6*rs_a)->imag, + (a11 + cs_a*4 + 6*rs_a)->real, + (a11 + cs_a*4 + 6*rs_a)->imag, + (a11 + cs_a*4 + 6*rs_a)->real, + (a11 + cs_a*4 + 6*rs_a)->imag, + (a11 + cs_a*4 + 6*rs_a)->real, + (a11 + cs_a*4 + 6*rs_a)->imag, + (a11 + cs_a*4 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41332,7 +45145,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*3 + 6*rs_a)->imag, + (a11 + cs_a*3 + 6*rs_a)->real, + (a11 + cs_a*3 + 6*rs_a)->imag, + (a11 + cs_a*3 + 6*rs_a)->real, + (a11 + cs_a*3 + 6*rs_a)->imag, + (a11 + cs_a*3 + 6*rs_a)->real, + (a11 + cs_a*3 + 6*rs_a)->imag, + (a11 + cs_a*3 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41347,7 +45167,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 6*rs_a)->imag, + (a11 + cs_a*2 + 6*rs_a)->real, + (a11 + cs_a*2 + 6*rs_a)->imag, + (a11 + cs_a*2 + 6*rs_a)->real, + (a11 + cs_a*2 + 6*rs_a)->imag, + (a11 + cs_a*2 + 6*rs_a)->real, + (a11 + cs_a*2 + 6*rs_a)->imag, + (a11 + cs_a*2 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41362,7 +45189,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 6*rs_a)->imag, + (a11 + cs_a*1 + 6*rs_a)->real, + (a11 + cs_a*1 + 6*rs_a)->imag, + (a11 + cs_a*1 + 6*rs_a)->real, + (a11 + cs_a*1 + 6*rs_a)->imag, + (a11 + cs_a*1 + 6*rs_a)->real, + (a11 + cs_a*1 + 6*rs_a)->imag, + (a11 + cs_a*1 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41377,7 +45211,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = 
_mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 6*rs_a)->imag, + (a11 + cs_a*0 + 6*rs_a)->real, + (a11 + cs_a*0 + 6*rs_a)->imag, + (a11 + cs_a*0 + 6*rs_a)->real, + (a11 + cs_a*0 + 6*rs_a)->imag, + (a11 + cs_a*0 + 6*rs_a)->real, + (a11 + cs_a*0 + 6*rs_a)->imag, + (a11 + cs_a*0 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41393,7 +45234,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 5)); + ymm1 = _mm256_set_ps((d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -41403,7 +45247,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4 + 5*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*4 + 5*rs_a)->imag, + (a11 + cs_a*4 + 5*rs_a)->real, + (a11 + cs_a*4 + 5*rs_a)->imag, + (a11 + cs_a*4 + 5*rs_a)->real, + (a11 + cs_a*4 + 5*rs_a)->imag, + (a11 + cs_a*4 + 5*rs_a)->real, + (a11 + cs_a*4 + 5*rs_a)->imag, + (a11 + cs_a*4 + 5*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41418,7 +45269,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 5*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*3 + 5*rs_a)->imag, + (a11 + cs_a*3 + 5*rs_a)->real, + (a11 + cs_a*3 + 5*rs_a)->imag, + (a11 + cs_a*3 + 5*rs_a)->real, + (a11 + cs_a*3 + 5*rs_a)->imag, + (a11 + cs_a*3 + 5*rs_a)->real, + (a11 + cs_a*3 + 5*rs_a)->imag, + (a11 + cs_a*3 + 5*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41433,7 +45291,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 5*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 5*rs_a)->imag, + (a11 + cs_a*2 + 5*rs_a)->real, + (a11 + cs_a*2 + 5*rs_a)->imag, + (a11 + cs_a*2 + 5*rs_a)->real, + (a11 + cs_a*2 + 5*rs_a)->imag, + (a11 + cs_a*2 + 5*rs_a)->real, + (a11 + cs_a*2 + 5*rs_a)->imag, + (a11 + cs_a*2 + 5*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41448,7 +45313,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 5*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 5*rs_a)->imag, + (a11 + cs_a*1 + 5*rs_a)->real, + (a11 + cs_a*1 + 5*rs_a)->imag, + (a11 + cs_a*1 + 5*rs_a)->real, + (a11 + cs_a*1 + 5*rs_a)->imag, + (a11 + cs_a*1 + 5*rs_a)->real, + (a11 + cs_a*1 + 5*rs_a)->imag, + (a11 + cs_a*1 + 5*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41463,7 +45335,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 5*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 5*rs_a)->imag, + (a11 + cs_a*0 + 5*rs_a)->real, + (a11 + cs_a*0 + 5*rs_a)->imag, + (a11 + cs_a*0 + 5*rs_a)->real, + (a11 + cs_a*0 + 5*rs_a)->imag, + (a11 + cs_a*0 + 
5*rs_a)->real, + (a11 + cs_a*0 + 5*rs_a)->imag, + (a11 + cs_a*0 + 5*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41479,7 +45358,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 4)); + ymm1 = _mm256_set_ps((d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm12) @@ -41488,7 +45370,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 4*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*3 + 4*rs_a)->imag, + (a11 + cs_a*3 + 4*rs_a)->real, + (a11 + cs_a*3 + 4*rs_a)->imag, + (a11 + cs_a*3 + 4*rs_a)->real, + (a11 + cs_a*3 + 4*rs_a)->imag, + (a11 + cs_a*3 + 4*rs_a)->real, + (a11 + cs_a*3 + 4*rs_a)->imag, + (a11 + cs_a*3 + 4*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41503,7 +45392,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 4*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 4*rs_a)->imag, + (a11 + cs_a*2 + 4*rs_a)->real, + (a11 + cs_a*2 + 4*rs_a)->imag, + (a11 + cs_a*2 + 4*rs_a)->real, + (a11 + cs_a*2 + 4*rs_a)->imag, + (a11 + cs_a*2 + 4*rs_a)->real, + (a11 + cs_a*2 + 4*rs_a)->imag, + (a11 + cs_a*2 + 4*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41518,7 +45414,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 4*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 4*rs_a)->imag, + (a11 + cs_a*1 + 4*rs_a)->real, + (a11 + cs_a*1 + 4*rs_a)->imag, + (a11 + cs_a*1 + 4*rs_a)->real, + (a11 + cs_a*1 + 4*rs_a)->imag, + (a11 + cs_a*1 + 4*rs_a)->real, + (a11 + cs_a*1 + 4*rs_a)->imag, + (a11 + cs_a*1 + 4*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41533,7 +45436,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 4*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 4*rs_a)->imag, + (a11 + cs_a*0 + 4*rs_a)->real, + (a11 + cs_a*0 + 4*rs_a)->imag, + (a11 + cs_a*0 + 4*rs_a)->real, + (a11 + cs_a*0 + 4*rs_a)->imag, + (a11 + cs_a*0 + 4*rs_a)->real, + (a11 + cs_a*0 + 4*rs_a)->imag, + (a11 + cs_a*0 + 4*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41549,7 +45459,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_set_ps((d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm11) @@ -41557,7 +45470,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB BLIS_CTRSM_MUL(ymm11) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 
+ cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41572,7 +45492,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41587,7 +45514,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41603,7 +45537,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -41612,7 +45549,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a + 2*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41627,7 +45571,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 2*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41643,7 +45594,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm9) @@ -41652,7 +45606,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a) ); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, 
+ (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41668,7 +45629,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) @@ -41699,8 +45663,8 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB BLIS_CTRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) float zero = 0.0; - ymm16 = _mm256_broadcast_ss(&AlphaVal.real); - ymm17 = _mm256_broadcast_ss(&AlphaVal.imag); + ymm16 = _mm256_broadcast_ss(&AlphaVal[0].real); + ymm17 = _mm256_broadcast_ss(&AlphaVal[0].imag); ymm2 = _mm256_broadcast_ss(&zero); ymm3 = _mm256_broadcast_ss(&zero); ymm6 = _mm256_broadcast_ss(&zero); @@ -41751,8 +45715,8 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB BLIS_CTRSM_SMALL_GEMM_8mx1n(a10,b01,cs_b,p_lda,k_iter) float zero = 0.0; - ymm16 = _mm256_broadcast_ss(&AlphaVal.real); - ymm17 = _mm256_broadcast_ss(&AlphaVal.imag); + ymm16 = _mm256_broadcast_ss(&AlphaVal[0].real); + ymm17 = _mm256_broadcast_ss(&AlphaVal[0].imag); ymm2 = _mm256_broadcast_ss(&zero); ymm3 = _mm256_broadcast_ss(&zero); ymm6 = _mm256_broadcast_ss(&zero); @@ -41821,14 +45785,24 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 7)); + ymm1 = _mm256_set_ps((d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm15) #else BLIS_CTRSM_MUL(ymm15) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*6 + 7*rs_a)->imag, + (a11 + cs_a*6 + 7*rs_a)->real, + (a11 + cs_a*6 + 7*rs_a)->imag, + (a11 + cs_a*6 + 7*rs_a)->real, + (a11 + cs_a*6 + 7*rs_a)->imag, + (a11 + cs_a*6 + 7*rs_a)->real, + (a11 + cs_a*6 + 7*rs_a)->imag, + (a11 + cs_a*6 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41845,7 +45819,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*5 + 7*rs_a)->imag, + (a11 + cs_a*5 + 7*rs_a)->real, + (a11 + cs_a*5 + 7*rs_a)->imag, + (a11 + cs_a*5 + 7*rs_a)->real, + (a11 + cs_a*5 + 7*rs_a)->imag, + (a11 + cs_a*5 + 7*rs_a)->real, + (a11 + cs_a*5 + 7*rs_a)->imag, + (a11 + cs_a*5 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41860,7 +45841,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*4 + 7*rs_a)->imag, + (a11 + cs_a*4 + 7*rs_a)->real, + (a11 + cs_a*4 + 7*rs_a)->imag, + (a11 + cs_a*4 + 7*rs_a)->real, + (a11 + cs_a*4 + 7*rs_a)->imag, + (a11 + cs_a*4 + 7*rs_a)->real, + (a11 + cs_a*4 + 
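/*
 * With alpha now held in the two-entry AlphaVal array, the scalar broadcasts
 * in the 8mx2n / 8mx1n remainder paths read AlphaVal[0].real and
 * AlphaVal[0].imag.  Both entries hold the same alpha, presumably so that the
 * broadcast_ps-of-&AlphaVal pattern (visible in the sibling AutXB kernel
 * above), which reads 16 bytes, stays within initialised data wherever the
 * array is used.
 */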
7*rs_a)->imag, + (a11 + cs_a*4 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41875,7 +45863,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*3 + 7*rs_a)->imag, + (a11 + cs_a*3 + 7*rs_a)->real, + (a11 + cs_a*3 + 7*rs_a)->imag, + (a11 + cs_a*3 + 7*rs_a)->real, + (a11 + cs_a*3 + 7*rs_a)->imag, + (a11 + cs_a*3 + 7*rs_a)->real, + (a11 + cs_a*3 + 7*rs_a)->imag, + (a11 + cs_a*3 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41890,7 +45885,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 7*rs_a)->imag, + (a11 + cs_a*2 + 7*rs_a)->real, + (a11 + cs_a*2 + 7*rs_a)->imag, + (a11 + cs_a*2 + 7*rs_a)->real, + (a11 + cs_a*2 + 7*rs_a)->imag, + (a11 + cs_a*2 + 7*rs_a)->real, + (a11 + cs_a*2 + 7*rs_a)->imag, + (a11 + cs_a*2 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41905,7 +45907,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 7*rs_a)->imag, + (a11 + cs_a*1 + 7*rs_a)->real, + (a11 + cs_a*1 + 7*rs_a)->imag, + (a11 + cs_a*1 + 7*rs_a)->real, + (a11 + cs_a*1 + 7*rs_a)->imag, + (a11 + cs_a*1 + 7*rs_a)->real, + (a11 + cs_a*1 + 7*rs_a)->imag, + (a11 + cs_a*1 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41920,7 +45929,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 7*rs_a)->imag, + (a11 + cs_a*0 + 7*rs_a)->real, + (a11 + cs_a*0 + 7*rs_a)->imag, + (a11 + cs_a*0 + 7*rs_a)->real, + (a11 + cs_a*0 + 7*rs_a)->imag, + (a11 + cs_a*0 + 7*rs_a)->real, + (a11 + cs_a*0 + 7*rs_a)->imag, + (a11 + cs_a*0 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41935,7 +45951,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 6)); + ymm1 = _mm256_set_ps((d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm14) @@ -41943,7 +45962,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB BLIS_CTRSM_MUL(ymm14) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*5 + 6*rs_a)->imag, + (a11 + cs_a*5 + 6*rs_a)->real, + (a11 + cs_a*5 + 6*rs_a)->imag, + (a11 + cs_a*5 + 6*rs_a)->real, + (a11 + cs_a*5 + 6*rs_a)->imag, + (a11 + cs_a*5 + 6*rs_a)->real, + (a11 + cs_a*5 + 6*rs_a)->imag, + (a11 + cs_a*5 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41958,7 +45984,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = 
_mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*4 + 6*rs_a)->imag, + (a11 + cs_a*4 + 6*rs_a)->real, + (a11 + cs_a*4 + 6*rs_a)->imag, + (a11 + cs_a*4 + 6*rs_a)->real, + (a11 + cs_a*4 + 6*rs_a)->imag, + (a11 + cs_a*4 + 6*rs_a)->real, + (a11 + cs_a*4 + 6*rs_a)->imag, + (a11 + cs_a*4 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41973,7 +46006,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*3 + 6*rs_a)->imag, + (a11 + cs_a*3 + 6*rs_a)->real, + (a11 + cs_a*3 + 6*rs_a)->imag, + (a11 + cs_a*3 + 6*rs_a)->real, + (a11 + cs_a*3 + 6*rs_a)->imag, + (a11 + cs_a*3 + 6*rs_a)->real, + (a11 + cs_a*3 + 6*rs_a)->imag, + (a11 + cs_a*3 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -41988,7 +46028,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 6*rs_a)->imag, + (a11 + cs_a*2 + 6*rs_a)->real, + (a11 + cs_a*2 + 6*rs_a)->imag, + (a11 + cs_a*2 + 6*rs_a)->real, + (a11 + cs_a*2 + 6*rs_a)->imag, + (a11 + cs_a*2 + 6*rs_a)->real, + (a11 + cs_a*2 + 6*rs_a)->imag, + (a11 + cs_a*2 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42003,7 +46050,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 6*rs_a)->imag, + (a11 + cs_a*1 + 6*rs_a)->real, + (a11 + cs_a*1 + 6*rs_a)->imag, + (a11 + cs_a*1 + 6*rs_a)->real, + (a11 + cs_a*1 + 6*rs_a)->imag, + (a11 + cs_a*1 + 6*rs_a)->real, + (a11 + cs_a*1 + 6*rs_a)->imag, + (a11 + cs_a*1 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42018,7 +46072,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 6*rs_a)->imag, + (a11 + cs_a*0 + 6*rs_a)->real, + (a11 + cs_a*0 + 6*rs_a)->imag, + (a11 + cs_a*0 + 6*rs_a)->real, + (a11 + cs_a*0 + 6*rs_a)->imag, + (a11 + cs_a*0 + 6*rs_a)->real, + (a11 + cs_a*0 + 6*rs_a)->imag, + (a11 + cs_a*0 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42034,7 +46095,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 5)); + ymm1 = _mm256_set_ps((d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -42044,7 +46108,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4 + 5*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*4 + 5*rs_a)->imag, + (a11 + cs_a*4 + 5*rs_a)->real, + (a11 + cs_a*4 + 5*rs_a)->imag, + (a11 + cs_a*4 + 5*rs_a)->real, + (a11 + cs_a*4 + 5*rs_a)->imag, + (a11 + cs_a*4 + 5*rs_a)->real, + (a11 + cs_a*4 + 5*rs_a)->imag, + (a11 + cs_a*4 + 5*rs_a)->real); ymm2 = 
_mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42059,7 +46130,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 5*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*3 + 5*rs_a)->imag, + (a11 + cs_a*3 + 5*rs_a)->real, + (a11 + cs_a*3 + 5*rs_a)->imag, + (a11 + cs_a*3 + 5*rs_a)->real, + (a11 + cs_a*3 + 5*rs_a)->imag, + (a11 + cs_a*3 + 5*rs_a)->real, + (a11 + cs_a*3 + 5*rs_a)->imag, + (a11 + cs_a*3 + 5*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42074,7 +46152,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 5*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 5*rs_a)->imag, + (a11 + cs_a*2 + 5*rs_a)->real, + (a11 + cs_a*2 + 5*rs_a)->imag, + (a11 + cs_a*2 + 5*rs_a)->real, + (a11 + cs_a*2 + 5*rs_a)->imag, + (a11 + cs_a*2 + 5*rs_a)->real, + (a11 + cs_a*2 + 5*rs_a)->imag, + (a11 + cs_a*2 + 5*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42089,7 +46174,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 5*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 5*rs_a)->imag, + (a11 + cs_a*1 + 5*rs_a)->real, + (a11 + cs_a*1 + 5*rs_a)->imag, + (a11 + cs_a*1 + 5*rs_a)->real, + (a11 + cs_a*1 + 5*rs_a)->imag, + (a11 + cs_a*1 + 5*rs_a)->real, + (a11 + cs_a*1 + 5*rs_a)->imag, + (a11 + cs_a*1 + 5*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42104,7 +46196,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 5*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 5*rs_a)->imag, + (a11 + cs_a*0 + 5*rs_a)->real, + (a11 + cs_a*0 + 5*rs_a)->imag, + (a11 + cs_a*0 + 5*rs_a)->real, + (a11 + cs_a*0 + 5*rs_a)->imag, + (a11 + cs_a*0 + 5*rs_a)->real, + (a11 + cs_a*0 + 5*rs_a)->imag, + (a11 + cs_a*0 + 5*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42120,7 +46219,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 4)); + ymm1 = _mm256_set_ps((d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm12) @@ -42129,7 +46231,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 4*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*3 + 4*rs_a)->imag, + (a11 + cs_a*3 + 4*rs_a)->real, + (a11 + cs_a*3 + 4*rs_a)->imag, + (a11 + cs_a*3 + 4*rs_a)->real, + (a11 + cs_a*3 + 4*rs_a)->imag, + (a11 + cs_a*3 + 4*rs_a)->real, + (a11 + cs_a*3 + 4*rs_a)->imag, + (a11 + cs_a*3 + 4*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42144,7 +46253,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 4*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 4*rs_a)->imag, + (a11 + 
cs_a*2 + 4*rs_a)->real, + (a11 + cs_a*2 + 4*rs_a)->imag, + (a11 + cs_a*2 + 4*rs_a)->real, + (a11 + cs_a*2 + 4*rs_a)->imag, + (a11 + cs_a*2 + 4*rs_a)->real, + (a11 + cs_a*2 + 4*rs_a)->imag, + (a11 + cs_a*2 + 4*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42159,7 +46275,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 4*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 4*rs_a)->imag, + (a11 + cs_a*1 + 4*rs_a)->real, + (a11 + cs_a*1 + 4*rs_a)->imag, + (a11 + cs_a*1 + 4*rs_a)->real, + (a11 + cs_a*1 + 4*rs_a)->imag, + (a11 + cs_a*1 + 4*rs_a)->real, + (a11 + cs_a*1 + 4*rs_a)->imag, + (a11 + cs_a*1 + 4*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42174,7 +46297,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 4*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 4*rs_a)->imag, + (a11 + cs_a*0 + 4*rs_a)->real, + (a11 + cs_a*0 + 4*rs_a)->imag, + (a11 + cs_a*0 + 4*rs_a)->real, + (a11 + cs_a*0 + 4*rs_a)->imag, + (a11 + cs_a*0 + 4*rs_a)->real, + (a11 + cs_a*0 + 4*rs_a)->imag, + (a11 + cs_a*0 + 4*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42190,7 +46320,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_set_ps((d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm11) @@ -42198,7 +46331,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB BLIS_CTRSM_MUL(ymm11) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42213,7 +46353,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42228,7 +46375,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42244,7 +46398,10 @@ BLIS_INLINE 
err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -42253,7 +46410,15 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a + 2*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real); + ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42268,7 +46433,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 2*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42284,7 +46456,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm9) @@ -42293,7 +46468,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a) ); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42309,7 +46491,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) @@ -42366,7 +46551,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB dim_t p_lda = 4; if(transa) { - for(dim_t x =0;x < m-i+4;x+=p_lda) + for(dim_t x =0;x < m-i-4;x+=p_lda) { ymm0 = _mm256_loadu_ps((float const *)(a10)); ymm1 = _mm256_loadu_ps((float const *)(a10 + cs_a)); @@ -42405,11 +46590,11 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB { if(transa) { - ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,m_rem); + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,4); } else { - ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,m_rem); + ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,4); } } @@ -42427,7 +46612,10 @@ BLIS_INLINE err_t 
bli_ctrsm_small_AltXB_AuXB BLIS_CTRSM_SMALL_NREG_TRANSPOSE_3x4(b11,cs_b,AlphaVal) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_set_ps((d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm11) @@ -42435,7 +46623,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB BLIS_CTRSM_MUL(ymm11) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42449,7 +46644,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42463,7 +46665,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42478,7 +46687,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -42487,7 +46699,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a + 2*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42501,7 +46720,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 2*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + 
(a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42516,7 +46742,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm9) @@ -42525,7 +46754,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a) ); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42540,7 +46776,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) @@ -42576,7 +46815,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB } ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_set_ps((d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm11) @@ -42584,7 +46826,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB BLIS_CTRSM_MUL(ymm11) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42598,7 +46847,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42612,7 +46868,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + 
(a11 + cs_a*0 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42627,7 +46890,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -42636,7 +46902,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a + 2*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42650,7 +46923,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 2*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42665,7 +46945,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm9) @@ -42674,7 +46957,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a) ); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42689,7 +46979,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) @@ -42788,14 +47081,23 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB k_iter = (m - m_rem); BLIS_SET_S_YMM_REG_ZEROS - BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_3mx3n(a10,b01,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = 
_mm256_insertf128_ps(ymm0, xmm0, 0); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 1); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + ymm1 = _mm256_insertf128_ps(ymm1, xmm0, 0); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + cs_b + 2)); + ymm1 = _mm256_insertf128_ps(ymm1, xmm0, 1); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2)); + ymm2 = _mm256_insertf128_ps(ymm2, xmm0, 0); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + cs_b*2 + 2)); + ymm2 = _mm256_insertf128_ps(ymm2, xmm0, 1); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); @@ -42858,7 +47160,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB if(2 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_3mx2n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) @@ -42875,7 +47177,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB else if(1 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_3mx1n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) @@ -42950,15 +47252,17 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB k_iter = (m - m_rem); BLIS_SET_S_YMM_REG_ZEROS - BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_2mx3n(a10,b01,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2)); - + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + ymm1 = _mm256_insertf128_ps(ymm1, xmm0, 0); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2)); + ymm2 = _mm256_insertf128_ps(ymm2, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); @@ -43014,7 +47318,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB if(2 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_2mx2n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) @@ -43031,7 +47335,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB else if(1 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_2mx1n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) @@ -43104,15 +47408,17 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB k_iter = (m - m_rem); BLIS_SET_S_YMM_REG_ZEROS - BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_1mx3n(a10,b01,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - - ymm0 = _mm256_loadu_ps((float const *)(b11)); - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2)); - + + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + cs_b)); + ymm1 = _mm256_insertf128_ps(ymm1, xmm0, 0); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + cs_b*2)); + ymm2 = _mm256_insertf128_ps(ymm2, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); @@ -43169,7 +47475,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB if(2 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_1mx2n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_1M_2N(AlphaVal,b11, @@ -43189,7 
+47495,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB else if(1 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_1mx1n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_1M_1N(AlphaVal,b11, @@ -43219,7 +47525,6 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB return BLIS_SUCCESS; } - BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ( obj_t* AlphaObj, @@ -43254,9 +47559,12 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB dim_t i, j, k; //loop variables dim_t k_iter; //number of times GEMM to be performed - scomplex AlphaVal = *(scomplex *)AlphaObj->buffer; //value of alpha - scomplex *L = a->buffer; //pointer to matrix A - scomplex *B = b->buffer; //pointer to matrix B + scomplex AlphaVal[2]; + AlphaVal[0] = *(scomplex *)AlphaObj->buffer; //value of alpha + AlphaVal[1] = *(scomplex *)AlphaObj->buffer; //value of alpha + + scomplex *L = bli_obj_buffer_at_off(a); //pointer to matrix A + scomplex *B = bli_obj_buffer_at_off(b); //pointer to matrix B scomplex *a01, *a11, *b10, *b11; //pointers that point to blocks for GEMM and TRSM @@ -43271,6 +47579,12 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB __m256 ymm16, ymm17, ymm18, ymm19; __m128 xmm0, xmm1, xmm2; + __m128 xmm5; + + xmm0 = _mm_setzero_ps(); + xmm1 = _mm_setzero_ps(); + xmm2 = _mm_setzero_ps(); + xmm5 = _mm_setzero_ps(); gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; @@ -43387,7 +47701,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB */ ////extract a00 ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_TWO_DIV(ymm12, ymm13) @@ -43395,7 +47712,14 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_MUL(ymm12) BLIS_CTRSM_MUL(ymm13) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a *2 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43419,7 +47743,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43440,7 +47767,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_TWO_DIV(ymm10, ymm11) @@ -43449,7 +47779,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_MUL(ymm11) #endif - ymm2 = 
_mm256_broadcast_ps((__m128 const *) (a11 + cs_a) ); + ymm2 = _mm256_set_ps((a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43469,7 +47802,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -43506,14 +47842,24 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB // Load b11 of size 4x6 and multiply with alpha BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm12) #else BLIS_CTRSM_MUL(ymm12) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a *2 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43531,7 +47877,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43545,7 +47894,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -43553,7 +47905,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_MUL(ymm10) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a) ); + ymm2 = _mm256_set_ps((a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43567,7 +47922,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + 
(d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -43598,20 +47956,30 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x3(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm12) #else BLIS_CTRSM_MUL(ymm12) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a *2 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43628,7 +47996,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43642,7 +48013,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -43650,7 +48024,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_MUL(ymm10) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a) ); + ymm2 = _mm256_set_ps((a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43664,7 +48041,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -43706,20 +48086,30 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - 
BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x2(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm12) #else BLIS_CTRSM_MUL(ymm12) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a *2 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43737,7 +48127,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43751,7 +48144,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -43759,7 +48155,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_MUL(ymm10) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a) ); + ymm2 = _mm256_set_ps((a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43773,7 +48172,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -43805,20 +48207,30 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x1(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 
= _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm12) #else BLIS_CTRSM_MUL(ymm12) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a *2 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43835,7 +48247,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43849,7 +48264,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -43857,7 +48275,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_MUL(ymm10) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a) ); + ymm2 = _mm256_set_ps((a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43871,7 +48292,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -44025,7 +48449,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm11 = _mm256_sub_ps(ymm19, ymm11); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_TWO_DIV(ymm10, ymm11) @@ -44034,7 +48461,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_MUL(ymm11) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a) ); + ymm2 = _mm256_set_ps((a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44054,7 +48484,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const 
*)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -44104,14 +48537,20 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm10 = _mm256_sub_ps(ymm19, ymm10); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) #else BLIS_CTRSM_MUL(ymm10) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44127,7 +48566,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -44152,11 +48594,14 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44165,7 +48610,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -44174,14 +48622,20 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm10 = _mm256_sub_ps(ymm19, ymm10); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) #else BLIS_CTRSM_MUL(ymm10) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + 
cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44197,7 +48651,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -44229,12 +48686,13 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB k_iter = (n-n_rem); BLIS_SET_S_YMM_REG_ZEROS - ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44243,7 +48701,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -44252,14 +48711,20 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm10 = _mm256_sub_ps(ymm19, ymm10); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) #else BLIS_CTRSM_MUL(ymm10) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44275,7 +48740,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -44302,11 +48770,12 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44315,7 +48784,8 @@ BLIS_INLINE err_t 
bli_ctrsm_small_XAutB_XAlB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -44324,14 +48794,20 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm10 = _mm256_sub_ps(ymm19, ymm10); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) #else BLIS_CTRSM_MUL(ymm10) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44347,7 +48823,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -44424,7 +48903,6 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 0 + x*3), xmm0); xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 0 + 2 + x*3)); _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 0 + 2 + x*3),xmm0); - } } @@ -44470,7 +48948,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_TWO_DIV(ymm8, ymm9) @@ -44510,7 +48991,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) @@ -44531,12 +49015,15 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB k_iter = (n-n_rem); BLIS_SET_S_YMM_REG_ZEROS - ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm17 = 
_mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44547,7 +49034,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) @@ -44571,11 +49061,12 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44586,7 +49077,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) @@ -44608,11 +49102,12 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44622,7 +49117,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm8 = _mm256_sub_ps(ymm19, ymm8); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) @@ -44681,9 +49179,12 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB dim_t i, j, k; //loop variables dim_t k_iter; //number of times GEMM to be performed - scomplex AlphaVal = *(scomplex *)AlphaObj->buffer; //value of alpha - scomplex *L = a->buffer; //pointer to matrix A - scomplex *B = b->buffer; //pointer to matrix B + scomplex AlphaVal[2]; + AlphaVal[0] = *(scomplex *)AlphaObj->buffer; //value of alpha + AlphaVal[1] = *(scomplex *)AlphaObj->buffer; //value of alpha + + scomplex *L = bli_obj_buffer_at_off(a); //pointer to matrix A + scomplex *B = bli_obj_buffer_at_off(b); //pointer to matrix B scomplex *a01, *a11, *b10, *b11; //pointers that point to blocks for GEMM and TRSM @@ -44697,6 +49198,12 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB __m256 ymm16, ymm17, ymm18, ymm19; __m128 xmm0, xmm1, xmm2; + __m128 xmm5; + + xmm0 = _mm_setzero_ps(); + xmm1 = _mm_setzero_ps(); + 
xmm2 = _mm_setzero_ps(); + xmm5 = _mm_setzero_ps(); gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; @@ -44814,7 +49321,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB */ ////extract a00 ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_TWO_DIV(ymm8, ymm9) @@ -44822,7 +49332,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_CTRSM_MUL(ymm8) BLIS_CTRSM_MUL(ymm9) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44846,7 +49359,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_set_ps((a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44868,7 +49384,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_TWO_DIV(ymm10, ymm11) @@ -44880,7 +49399,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB a11 += cs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_set_ps((a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44901,7 +49423,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -44940,14 +49465,20 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + rs_a*1)->imag,(a11 
+ rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44964,7 +49495,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_set_ps((a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44980,7 +49514,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -44991,7 +49528,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB a11 += cs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_set_ps((a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45007,7 +49547,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -45036,20 +49579,26 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x3(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45066,7 +49615,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_set_ps((a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real); 
ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45082,7 +49634,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -45093,7 +49648,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB a11 += cs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_set_ps((a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45109,7 +49667,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -45118,15 +49679,6 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_CTRSM_MUL(ymm12) #endif -/* ymm0 = _mm256_loadu_ps((float const *)b11); - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b)); - ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b * 2)); - ymm8 = _mm256_blend_ps(ymm8, ymm0, 0xC0); - ymm10 = _mm256_blend_ps(ymm10, ymm1, 0xC0); - ymm12 = _mm256_blend_ps(ymm12, ymm2, 0xC0); - _mm256_storeu_ps((float *)b11, ymm8); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm10); - _mm256_storeu_ps((float *)(b11 + cs_b * 2), ymm12);*/ xmm0 = _mm256_extractf128_ps(ymm8, 0); xmm1 = _mm256_extractf128_ps(ymm8, 1); _mm_storeu_ps((float *)(b11), xmm0); @@ -45158,20 +49710,26 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x2(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45188,7 +49746,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_set_ps((a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + 
rs_a*2)->imag,(a11 + rs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45204,7 +49765,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -45215,7 +49779,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB a11 += cs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_set_ps((a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45231,7 +49798,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -45240,16 +49810,6 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_CTRSM_MUL(ymm12) #endif -/* ymm0 = _mm256_loadu_ps((float const *)b11); - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b)); - ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b * 2)); - ymm8 = _mm256_blend_ps(ymm8, ymm0, 0xF0); - ymm10 = _mm256_blend_ps(ymm10, ymm1, 0xF0); - ymm12 = _mm256_blend_ps(ymm12, ymm2, 0xF0); - _mm256_storeu_ps((float *)b11, ymm8); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm10); - _mm256_storeu_ps((float *)(b11 + cs_b * 2), ymm12); -*/ xmm0 = _mm256_extractf128_ps(ymm8, 0); _mm_storeu_ps((float *)(b11), xmm0); xmm0 = _mm256_extractf128_ps(ymm10, 0); @@ -45274,20 +49834,26 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x1(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45304,7 +49870,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_set_ps((a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + 
rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45320,7 +49889,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -45331,7 +49903,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB a11 += cs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_set_ps((a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45347,7 +49922,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -45356,16 +49934,6 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_CTRSM_MUL(ymm12) #endif -/* ymm0 = _mm256_loadu_ps((float const *)b11); - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b)); - ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b * 2)); - ymm8 = _mm256_blend_ps(ymm8, ymm0, 0xFC); - ymm10 = _mm256_blend_ps(ymm10, ymm1, 0xFC); - ymm12 = _mm256_blend_ps(ymm12, ymm2, 0xFC); - _mm256_storeu_ps((float *)b11, ymm8); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm10); - _mm256_storeu_ps((float *)(b11 + cs_b * 2), ymm12); -*/ xmm0 = _mm256_extractf128_ps(ymm8, 0); xmm1 = _mm256_extractf128_ps(ymm10, 0); xmm2 = _mm256_extractf128_ps(ymm12, 0); @@ -45508,7 +50076,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_TWO_DIV(ymm8, ymm9) @@ -45516,7 +50087,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_CTRSM_MUL(ymm8) BLIS_CTRSM_MUL(ymm9) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45539,7 +50113,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -45589,14 +50166,20 @@ 
BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm10 = _mm256_sub_ps(ymm19, ymm10); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45613,7 +50196,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -45639,11 +50225,14 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -45652,7 +50241,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -45661,14 +50253,20 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm10 = _mm256_sub_ps(ymm19, ymm10); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45685,7 +50283,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 
const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -45719,11 +50320,13 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -45732,7 +50335,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -45741,14 +50345,20 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm10 = _mm256_sub_ps(ymm19, ymm10); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45765,7 +50375,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -45794,11 +50407,12 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -45807,7 +50421,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm18 = _mm256_shuffle_ps(ymm0, 
ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -45816,14 +50431,20 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm10 = _mm256_sub_ps(ymm19, ymm10); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45840,7 +50461,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -45960,7 +50584,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_TWO_DIV(ymm8, ymm9) @@ -46000,7 +50627,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) @@ -46025,11 +50655,14 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -46040,7 +50673,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) @@ -46067,11 +50703,12 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB 
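Two access-narrowing patterns recur throughout the ctrsm hunks above, and a short sketch may make them easier to review. First, broadcasts of a single scomplex value from a11 or d11_pack via _mm256_broadcast_ps (a 16-byte load that also covers the neighbouring element) are rewritten as _mm256_set_ps over that one element's real/imag fields, presumably so that only the 8 valid bytes are ever read; either form yields the same register contents once the existing _mm256_permute_ps(..., 0x44) is applied. Second, in the m-remainder blocks the full 256-bit loads of b11 are replaced by 128-bit and 64-bit loads stitched together with _mm256_insertf128_ps, again so that no element beyond the remainder is touched. The helpers below are a minimal sketch of the two idioms, not part of the patch; sc_t and the helper names are assumptions standing in for the BLIS scomplex type.

#include <immintrin.h>

/* Hypothetical stand-in for the BLIS scomplex layout (float real, float imag). */
typedef struct { float real; float imag; } sc_t;

/* Builds [r, i, r, i, r, i, r, i] from one complex value while touching only its
   8 bytes.  _mm256_set_ps lists elements from slot 7 down to slot 0, so each
   (imag, real) pair lands as [real, imag] in memory order; the kernel's
   subsequent _mm256_permute_ps(v, 0x44) leaves this layout unchanged. */
static inline __m256 bcast_one_sc(const sc_t *x)
{
    return _mm256_set_ps(x->imag, x->real, x->imag, x->real,
                         x->imag, x->real, x->imag, x->real);
}

/* Loads exactly three complex values (24 bytes): two with a 128-bit load, the
   third with a 64-bit load, stitched into one ymm register.  The kernel leaves
   the top two float slots unspecified and never stores them back; the scratch
   register is zeroed here for clarity. */
static inline __m256 load_three_sc(const sc_t *b)
{
    __m128 lo = _mm_loadu_ps((const float *)b);                          /* b[0], b[1] */
    __m128 hi = _mm_loadl_pi(_mm_setzero_ps(), (const __m64 *)(b + 2));  /* b[2]       */
    return _mm256_insertf128_ps(_mm256_castps128_ps256(lo), hi, 1);
}

The store side of these fringe paths is handled similarly: results are extracted with _mm256_extractf128_ps and written back with xmm-level stores rather than full 256-bit stores.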
BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -46082,7 +50719,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) @@ -46107,11 +50747,12 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -46121,7 +50762,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm8 = _mm256_sub_ps(ymm19, ymm8); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) @@ -46148,4 +50792,29 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB return BLIS_SUCCESS; } + +/* + * Check if the TRSM small path should be taken for this + * input and threads combination + */ +bool bli_cntx_trsm_small_thresh_is_met_zen(obj_t* a,dim_t m, dim_t n) +{ + rntm_t rntm; + bli_rntm_init_from_global(&rntm); + dim_t n_threads = bli_rntm_num_threads(&rntm); + + if(bli_obj_is_dcomplex(a)) + { + if ((n_threads > 1) && (n_threads <= 8) && (m <= 500) && (n <= 500)) + { + return true; + } + else + { + return false; + } + } + return false; +} + #endif //BLIS_ENABLE_SMALL_MATRIX_TRSM diff --git a/kernels/zen/3/bli_zgemm_ref_k1.c b/kernels/zen/3/bli_zgemm_ref_k1.c new file mode 100644 index 0000000000..47de706238 --- /dev/null +++ b/kernels/zen/3/bli_zgemm_ref_k1.c @@ -0,0 +1,1826 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.h" +#include "immintrin.h" + +#define Z_MR 4 +#define Z_NR 6 + +// Macros for the main loop for M +#define SCALE_ALPHA_REAL_M_LOOP(rin_0,rin_1,r_bcast,real_val) \ + r_bcast = _mm256_broadcast_sd((double const *)(&real_val)); \ + rin_0 = _mm256_mul_pd(rin_0,r_bcast); \ + rin_1 = _mm256_mul_pd(rin_1,r_bcast); \ + +#define SCALE_ALPHA_IMAG_M_LOOP(rout_0,rout_1,rin_0,rin_1,r_bcast,r_perm,imag_val) \ + r_perm = _mm256_permute4x64_pd(rin_0,0b10110001); \ + r_bcast = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ + r_perm = _mm256_mul_pd(r_bcast, r_perm); \ + r_bcast = _mm256_broadcast_sd((double const *)(&imag_val)); \ + rout_0 = _mm256_fmadd_pd(r_perm,r_bcast,rout_0); \ + r_perm = _mm256_permute4x64_pd(rin_1,0b10110001); \ + r_bcast = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ + r_perm = _mm256_mul_pd(r_bcast, r_perm); \ + r_bcast = _mm256_broadcast_sd((double const *)(&imag_val)); \ + rout_1 = _mm256_fmadd_pd(r_perm,r_bcast,rout_1); \ + +#define NEG_PERM_M_LOOP(r0,r1,r2) \ + r0 = _mm256_permute4x64_pd(r0,0b10110001); \ + r1 = _mm256_permute4x64_pd(r1,0b10110001); \ + r2 = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ + r0 = _mm256_mul_pd(r2, r0); \ + r1 = _mm256_mul_pd(r2, r1); \ + +#define FMA_M_LOOP(rin_0,rin_1,rout_0,rout_1,rbc,loc) \ + rbc = _mm256_broadcast_sd(loc); \ + rout_0 = _mm256_fmadd_pd(rbc, rin_0, rout_0); \ + rout_1 = _mm256_fmadd_pd(rbc, rin_1, rout_1); \ + +#define SCALE_BETA_REAL_M_LOOP(rin_0,rin_1,rout_0,rout_1,rbc) \ + rout_0 = _mm256_fmadd_pd(rbc, rin_0, rout_0); \ + rout_1 = _mm256_fmadd_pd(rbc, rin_1, rout_1); \ + +#define SCALE_BETA_IMAG_M_LOOP(rin_0,rin_1,rout_0,rout_1,rbc,rn) \ + NEG_PERM_M_LOOP(rin_0,rin_1,rn); \ + rout_0 = _mm256_fmadd_pd(rbc, rin_0, rout_0); \ + rout_1 = _mm256_fmadd_pd(rbc, rin_1, rout_1); \ + +// Macros for fringe cases with M +#define SCALE_ALPHA_REAL_M_FRINGE(rin_0,r_bcast,real_val) \ + r_bcast = _mm256_broadcast_sd((double const *)(&real_val)); \ + rin_0 = _mm256_mul_pd(rin_0,r_bcast); \ + +#define SCALE_ALPHA_IMAG_M_FRINGE(rout_0,rin_0,r_bcast,r_perm,imag_val) \ + r_perm = _mm256_permute4x64_pd(rin_0,0b10110001); \ + r_bcast = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ + r_perm = _mm256_mul_pd(r_bcast, r_perm); \ + r_bcast = _mm256_broadcast_sd((double const *)(&imag_val)); \ + rout_0 = _mm256_fmadd_pd(r_perm,r_bcast,rout_0); \ + +#define NEG_PERM_M_FRINGE(r0,r2) \ + r0 = _mm256_permute4x64_pd(r0,0b10110001); \ + r2 = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ + r0 = _mm256_mul_pd(r2, r0); \ + +#define FMA_M_FRINGE(r_in,r_out,r_bc,loc) \ + r_bc = _mm256_broadcast_sd(loc); \ + 
r_out = _mm256_fmadd_pd(r_bc, r_in, r_out); \ + +#define SCALE_BETA_REAL_M_FRINGE(rin_0,rout_0,rbc) \ + rout_0 = _mm256_fmadd_pd(rbc, rin_0, rout_0); \ + +#define SCALE_BETA_IMAG_M_FRINGE(rin_0,rout_0,rbc,rn) \ + NEG_PERM_M_FRINGE(rin_0,rn); \ + rout_0 = _mm256_fmadd_pd(rbc, rin_0, rout_0); \ + +void bli_zgemm_ref_k1_nn +( + dim_t m, + dim_t n, + dim_t k, + dcomplex* alpha, + dcomplex* a, const inc_t lda, + dcomplex* b, const inc_t ldb, + dcomplex* beta, + dcomplex* c, const inc_t ldc + ) +{ + + double alpha_real, beta_real; + double alpha_imag, beta_imag; + + alpha_real = alpha->real; + beta_real = beta->real; + alpha_imag = alpha->imag; + beta_imag = beta->imag; + + /* If m or n is zero, return immediately. */ + if ( bli_zero_dim2( m, n ) ) return; + /* If alpha alone is zero, scale by beta and return. */ + if (bli_zeq0(*(alpha))) + { + bli_zscalm( + BLIS_NO_CONJUGATE, + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + m, n, + beta, + c, 1, ldc + ); + return; + } + + dim_t m_remainder = (m % Z_MR); + dim_t n_remainder = (n % Z_NR); + + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m128d xmm5; + + /* Form C = alpha*A*B + beta*c */ + // Main loop along N dimension + for(dim_t j = 0;j < (n-Z_NR+1);j=j+Z_NR) + { + dcomplex* temp_b = b + j*ldb; + dcomplex* temp_a = a; + dcomplex* temp_c = c + j*ldc; + + //Main loop along M dimension + for(dim_t i = 0;i < (m-Z_MR+1);i=i+Z_MR) + { + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_real, alpha_imag + where alpha_real and/or alpha_imag is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_b where + computing all Z_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
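              d. Worked out for one element (writing a = ar + i*ai for an entry of
                 temp_a after the alpha scaling below, and b = br + i*bi for an entry
                 of temp_b): the first FMA_M_LOOP pass broadcasts br and accumulates
                 [ar*br, ai*br]; NEG_PERM_M_LOOP then turns [ar, ai] into [-ai, ar],
                 so the second pass with bi broadcast adds [-ai*bi, ar*bi].  Each
                 pair of double slots therefore ends up holding
                 [ar*br - ai*bi, ai*br + ar*bi] = [R(a*b), I(a*b)].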
+ */ + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + //R(a[2][0]) I(a[2][0]) R(a[3][0]) I(a[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_a + 2)); + + ymm13 = ymm0; + ymm14 = ymm1; + _mm_prefetch((char*)(temp_a + 32), _MM_HINT_T0); + + SCALE_ALPHA_REAL_M_LOOP(ymm0,ymm1,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm15,ymm2,alpha_imag); + + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + + /* + The result after scaling with alpha_real and/or alpha_imag is as follows: + For ymm0 : + R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) + I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] + R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) + I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) + + For ymm1 : + R(a[2][0]) = alpha_real*R(a[2][0])-alpha_imag*I(a[2][0]) + I(a[2][0]) = alpha_real*I(a[2][0])+alpha_imag*R[2][0] + R(a[3][0]) = alpha_real*R(a[3][0])-alpha_imag*I(a[3][0]) + I(a[3][0]) = alpha_real*I(a[3][0])+alpha_imag*(R[3][0]) + */ + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + //ymm4+=R(b[0][0])*R(a[2][0]) R(b[0][0])*I(a[2][0]) + // R(b[0][0])*R(a[3][0]) R(b[0][0])*I(a[3][0]) + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + //ymm6+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)); + //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) + // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) + //ymm8+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) + // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm2,(double const *)(temp_b+ldb*2)); + //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) + // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) + //ymm10+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) + // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm2,(double const *)(temp_b+ldb*3)); + //ymm11+=R(b[0][4])*R(a[0][0]) R(b[0][4])*I(a[0][0]) + // R(b[0][4])*R(a[1][0]) R(b[0][4])*I(a[1][0]) + //ymm12+=R(b[0][4])*R(a[0][0]) R(b[0][4])*I(a[0][0]) + // R(b[0][4])*R(a[1][0]) R(b[0][4])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm11,ymm12,ymm2,(double const *)(temp_b+ldb*4)); + //ymm11+=R(b[0][5])*R(a[0][0]) R(b[0][5])*I(a[0][0]) + // R(b[0][5])*R(a[1][0]) R(b[0][5])*I(a[1][0]) + //ymm12+=R(b[0][5])*R(a[0][0]) R(b[0][5])*I(a[0][0]) + // R(b[0][5])*R(a[1][0]) R(b[0][5])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm2,(double const *)(temp_b+ldb*5)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 and ymm1 in accordance to the requirement + NEG_PERM_M_LOOP(ymm0,ymm1,ymm2); + //ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + //ymm4+=R(b[0][0])*R(a[2][0]) I(b[0][0])*I(a[2][0]) + // I(b[0][0])*R(a[3][0]) I(b[0][0])*I(a[3][0]) + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + //ymm6+=R(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)+1); + //ymm7+=I(b[0][2])*R(a[0][0]) 
I(b[0][2])*I(a[0][0]) + // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) + //ymm8+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) + // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm2,(double const *)(temp_b+ldb*2)+1); + //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) + // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) + //ymm10+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) + // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm2,(double const *)(temp_b+ldb*3)+1); + //ymm11+=I(b[0][4])*R(a[0][0]) I(b[0][4])*I(a[0][0]) + // I(b[0][4])*R(a[1][0]) I(b[0][4])*I(a[1][0]) + //ymm12+=I(b[0][4])*R(a[0][0]) I(b[0][4])*I(a[0][0]) + // I(b[0][4])*R(a[1][0]) I(b[0][4])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm11,ymm12,ymm2,(double const *)(temp_b+ldb*4)+1); + //ymm13+=I(b[0][5])*R(a[0][0]) I(b[0][5])*I(a[0][0]) + // I(b[0][5])*R(a[1][0]) I(b[0][5])*I(a[1][0]) + //ymm14+=I(b[0][5])*R(a[0][0]) I(b[0][5])*I(a[0][0]) + // I(b[0][5])*R(a[1][0]) I(b[0][5])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm2,(double const *)(temp_b+ldb*5)+1); + + /* + a. Perform beta*C using temp_c, beta_real, + where beta_real is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. + */ + if(beta_real != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) + // beta_real*R(c[1][0]) beta_real*I(c[1][0]) + //ymm4+=beta_real*R(c[2][0]) beta_real*I(c[2][0]) + // beta_real*R(c[3][0]) beta_real*I(c[3][0]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); + //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) + // beta_real*R(c[1][1]) beta_real*I(c[1][1]) + //ymm6+=beta_real*R(c[2][1]) beta_real*I(c[2][1]) + // beta_real*R(c[3][1]) beta_real*I(c[3][1]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15); + + //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); + //R(c[2][2]) I(c[2][2]) R(c[3][2]) I(c[3][2]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*2 + 2)); + //ymm7+=beta_real*R(c[0][2]) beta_real*I(c[0][2]) + // beta_real*R(c[1][2]) beta_real*I(c[1][2]) + //ymm8+=beta_real*R(c[2][2]) beta_real*I(c[2][2]) + //beta_real*R(c[3][2]) beta_real*I(c[3][2]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm15); + + //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); + //R(c[2][3]) I(c[2][3]) R(c[3][3]) I(c[3][3]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*3 + 2)); + //ymm9+=beta_real*R(c[0][3]) beta_real*I(c[0][3]) + // beta_real*R(c[1][3]) beta_real*I(c[1][3]) + //ymm10+=beta_real*R(c[2][3]) beta_real*I(c[2][3]) + // beta_real*R(c[3][3]) beta_real*I(c[3][3]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm15); + + //R(c[0][4]) I(c[0][4]) R(c[1][4]) I(c[1][4]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*4)); + //R(c[2][4]) I(c[2][4]) R(c[3][4]) 
I(c[3][4]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*4 + 2)); + //ymm11+=beta_real*R(c[0][4]) beta_real*I(c[0][4]) + // beta_real*R(c[1][4]) beta_real*I(c[1][4]) + //ymm12+=beta_real*R(c[2][4]) beta_real*I(c[2][4]) + // beta_real*R(c[3][4]) beta_real*I(c[3][4]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm11,ymm12,ymm15); + + //R(c[0][5]) I(c[0][5]) R(c[1][5]) I(c[1][5]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*5)); + //R(c[2][5]) I(c[2][5]) R(c[3][5]) I(c[3][5]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*5 + 2)); + //ymm13+=beta_real*R(c[0][5]) beta_real*I(c[0][5]) + // beta_real*R(c[1][5]) beta_real*I(c[1][5]) + //ymm14+=beta_real*R(c[2][5]) beta_real*I(c[2][5]) + // beta_real*R(c[3][5]) beta_real*I(c[3][5]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm15); + } + + /* + a. Perform beta*C using temp_c, beta_imag, + where beta_imag is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. + */ + + if(beta_imag != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) + // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) + //ymm4+=beta_imag*(-I(c[2][0])) beta_imag*R(c[2][0]) + // beta_imag*(-I(c[3][0])) beta_imag*R(c[3][0]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15,ymm2); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); + //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) + // beta_imag*(-I(c[1][1])) beta_imag*R(c[1][1]) + //ymm6+=beta_imag*(-I(c[2][1])) beta_imag*R(c[2][1]) + // beta_imag*(-I(c[3][1])) beta_imag*R(c[3][1]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15,ymm2); + + //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); + //R(c[2][2]) I(c[2][2]) R(c[3][2]) I(c[3][2]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*2 + 2)); + //ymm7+=beta_imag*(-I(c[0][2])) beta_imag*R(c[0][2]) + // beta_imag*(-I(c[1][2])) beta_imag*R(c[1][2]) + //ymm8+=beta_imag*(-I(c[2][2])) beta_imag*R(c[2][2]) + // beta_imag*(-I(c[3][2])) beta_imag*R(c[3][2]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm15,ymm2); + + //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); + //R(c[2][3]) I(c[2][3]) R(c[3][3]) I(c[3][3]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*3 + 2)); + //ymm9+=beta_imag*(-I(c[0][3])) beta_imag*R(c[0][3]) + // beta_imag*(-I(c[1][3])) beta_imag*R(c[1][3]) + //ymm10+=beta_imag*(-I(c[2][3])) beta_imag*R(c[2][3]) + // beta_imag*(-I(c[3][3])) beta_imag*R(c[3][3]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm15,ymm2); + + //R(c[0][4]) I(c[0][4]) R(c[1][4]) I(c[1][4]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*4)); + //R(c[2][4]) I(c[2][4]) R(c[3][4]) I(c[3][4]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*4 + 2)); + //ymm11+=beta_imag*(-I(c[0][4])) beta_imag*R(c[0][4]) + // beta_imag*(-I(c[1][4])) beta_imag*R(c[1][4]) + //ymm12+=beta_imag*(-I(c[2][4])) 
beta_imag*R(c[2][4]) + // beta_imag*(-I(c[3][4])) beta_imag*R(c[3][4]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm11,ymm12,ymm15,ymm2); + + //R(c[0][5]) I(c[0][5]) R(c[1][5]) I(c[1][5]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*5)); + //R(c[2][5]) I(c[2][5]) R(c[3][5]) I(c[3][5]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*5 + 2)); + //ymm13+=beta_imag*(-I(c[0][5])) beta_imag*R(c[0][5]) + // beta_imag*(-I(c[1][5])) beta_imag*R(c[1][5]) + //ymm14+=beta_imag*(-I(c[2][5])) beta_imag*R(c[2][5]) + // beta_imag*(-I(c[3][5])) beta_imag*R(c[3][5]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm15,ymm2); + } + /* + The scaling has been done sequentially as follows: + - If alpha_real is not 0, it is used for scaling A + - If alpha_imag is not 0, it is used for scaling A using permutation + and selective negation, after loading + - If beta_real is not 0, is is used for scaling C + - If beta_imag is not 0, it is used for scaling C using permutation + and selective negation, after loading + + The results are accumalated in accordance to the non zero scalar values, + and similar approach is followed in fringe cases + */ + + _mm256_storeu_pd((double *)(temp_c), ymm3); + _mm256_storeu_pd((double *)(temp_c + 2), ymm4); + + _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); + _mm256_storeu_pd((double *)(temp_c + ldc + 2), ymm6); + + _mm256_storeu_pd((double *)(temp_c + ldc*2), ymm7); + _mm256_storeu_pd((double *)(temp_c + ldc*2 + 2), ymm8); + + _mm256_storeu_pd((double *)(temp_c + ldc*3), ymm9); + _mm256_storeu_pd((double *)(temp_c + ldc*3 + 2), ymm10); + + _mm256_storeu_pd((double *)(temp_c + ldc*4), ymm11); + _mm256_storeu_pd((double *)(temp_c + ldc*4 + 2), ymm12); + + _mm256_storeu_pd((double *)(temp_c + ldc*5), ymm13); + _mm256_storeu_pd((double *)(temp_c + ldc*5 + 2), ymm14); + + temp_c+=Z_MR; + temp_a+=Z_MR; + } + + // Fringe cases for M + dim_t m_rem=m_remainder; + if(m_rem>=2) + { + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + + ymm13 = ymm0; + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); + + ymm13 = _mm256_setzero_pd(); + + /* + The result after scaling with alpha_real and/or alpha_imag is as follows: + For ymm0 : + R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) + I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] + R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) + I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) + */ + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); + //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) + // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)); + //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) + // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)); + //ymm11+=R(b[0][4])*R(a[0][0]) R(b[0][4])*I(a[0][0]) + // R(b[0][4])*R(a[1][0]) R(b[0][4])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm11,ymm2,(double const 
*)(temp_b+ldb*4)); + //ymm13+=R(b[0][5])*R(a[0][0]) R(b[0][5])*I(a[0][0]) + // R(b[0][5])*R(a[1][0]) R(b[0][5])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm13,ymm2,(double const *)(temp_b+ldb*5)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); + //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) + // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)+1); + //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) + // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)+1); + //ymm11+=I(b[0][4])*R(a[0][0]) I(b[0][4])*I(a[0][0]) + // I(b[0][4])*R(a[1][0]) I(b[0][4])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm11,ymm2,(double const *)(temp_b+ldb*4)+1); + //ymm13+=I(b[0][5])*R(a[0][0]) I(b[0][5])*I(a[0][0]) + // I(b[0][5])*R(a[1][0]) I(b[0][5])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm13,ymm2,(double const *)(temp_b+ldb*5)+1); + + + if(beta_real != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) + // beta_real*R(c[1][0]) beta_real*I(c[1][0]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) + // beta_real*R(c[1][1]) beta_real*I(c[1][1]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); + + //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); + //ymm7+=beta_real*R(c[0][2]) beta_real*I(c[0][2]) + // beta_real*R(c[1][2]) beta_real*I(c[1][2]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm7,ymm15); + + //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); + //ymm9+=beta_real*R(c[0][3]) beta_real*I(c[0][3]) + // beta_real*R(c[1][3]) beta_real*I(c[1][3]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm9,ymm15); + + //R(c[0][4]) I(c[0][4]) R(c[1][4]) I(c[1][4]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*4)); + //ymm11+=beta_real*R(c[0][4]) beta_real*I(c[0][4]) + // beta_real*R(c[1][4]) beta_real*I(c[1][4]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm11,ymm15); + + //R(c[0][5]) I(c[0][5]) R(c[1][5]) I(c[1][5]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*5)); + //ymm13+=beta_real*R(c[0][5]) beta_real*I(c[0][5]) + // beta_real*R(c[1][5]) beta_real*I(c[1][5]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm13,ymm15); + } + + + if(beta_imag != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) + // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) + // beta_imag*(-I(c[1][1])) beta_imag*R(c[1][1]) + 
SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); + + //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); + //ymm7+=beta_imag*(-I(c[0][2])) beta_imag*R(c[0][2]) + // beta_imag*(-I(c[1][2])) beta_imag*R(c[1][2]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm7,ymm15,ymm2); + + //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); + //ymm9+=beta_imag*(-I(c[0][3])) beta_imag*R(c[0][3]) + // beta_imag*(-I(c[1][3])) beta_imag*R(c[1][3]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm9,ymm15,ymm2); + + //R(c[0][4]) I(c[0][4]) R(c[1][4]) I(c[1][4]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*4)); + //ymm11+=beta_imag*(-I(c[0][4])) beta_imag*R(c[0][4]) + // beta_imag*(-I(c[1][4])) beta_imag*R(c[1][4]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm11,ymm15,ymm2); + + //R(c[0][5]) I(c[0][5]) R(c[1][5]) I(c[1][5]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*5)); + //ymm13+=beta_imag*(-I(c[0][5])) beta_imag*R(c[0][5]) + // beta_imag*(-I(c[1][5])) beta_imag*R(c[1][5]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2); + } + + /* + The scaling has been done sequentially as follows: + - If alpha_real is not 0, it is used for scaling A + - If alpha_imag is not 0, it is used for scaling A using permutation + and selective negation, after loading + - If beta_real is not 0, is is used for scaling C + - If beta_imag is not 0, it is used for scaling C using permutation + and selective negation, after loading + + The results are accumalated in accordance to the non zero scalar values. + */ + + _mm256_storeu_pd((double *)(temp_c), ymm3); + _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); + _mm256_storeu_pd((double *)(temp_c + ldc*2), ymm7); + _mm256_storeu_pd((double *)(temp_c + ldc*3), ymm9); + _mm256_storeu_pd((double *)(temp_c + ldc*4), ymm11); + _mm256_storeu_pd((double *)(temp_c + ldc*5), ymm13); + + temp_c+=2; + temp_a+=2; + + m_rem -= 2; + } + + if(m_rem==1) + { + + xmm5 = _mm_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + + xmm5 = _mm_loadu_pd((double const*)(temp_a));//R(a[0][0]) I(a[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(a[0][0]) I(a[0][0]) + + ymm13 = ymm0; + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); + + ymm13 = _mm256_setzero_pd(); + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); + //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) + // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)); + //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) + // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)); + //ymm11+=R(b[0][4])*R(a[0][0]) R(b[0][4])*I(a[0][0]) + // R(b[0][4])*R(a[1][0]) R(b[0][4])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm11,ymm2,(double const *)(temp_b+ldb*4)); + //ymm13+=R(b[0][5])*R(a[0][0]) R(b[0][5])*I(a[0][0]) + // R(b[0][5])*R(a[1][0]) R(b[0][5])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm13,ymm2,(double const *)(temp_b+ldb*5)); + + //Calculating using imaginary part of 
complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); + //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) + // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)+1); + //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) + // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)+1); + //ymm11+=I(b[0][4])*R(a[0][0]) I(b[0][4])*I(a[0][0]) + // I(b[0][4])*R(a[1][0]) I(b[0][4])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm11,ymm2,(double const *)(temp_b+ldb*4)+1); + //ymm13+=I(b[0][5])*R(a[0][0]) I(b[0][5])*I(a[0][0]) + // I(b[0][5])*R(a[1][0]) I(b[0][5])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm13,ymm2,(double const *)(temp_b+ldb*5)+1); + + if(beta_real != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); + + xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) + //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 2));//R(c[0][2]) I(c[0][2]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][2]) I(c[0][2]) + //ymm7+=beta_real*R(c[0][2]) beta_real*I(c[0][2]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm7,ymm15); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 3));//R(c[0][3]) I(c[0][3]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][3]) I(c[0][3]) + //ymm9+=beta_real*R(c[0][3]) beta_real*I(c[0][3]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm9,ymm15); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 4));//R(c[0][4]) I(c[0][4]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][4]) I(c[0][4]) + //ymm11+=beta_real*R(c[0][4]) beta_real*I(c[0][4]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm11,ymm15); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 5));//R(c[0][5]) I(c[0][5]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][5]) I(c[0][5]) + //ymm13+=beta_real*R(c[0][5]) beta_real*I(c[0][5]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm13,ymm15); + } + + if(beta_imag != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); + + xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) + //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 2));//R(c[0][2]) I(c[0][2]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][2]) I(c[0][2]) + //ymm7+=beta_imag*(-I(c[0][2])) beta_imag*R(c[0][2]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm7,ymm15,ymm2); + + xmm5 = 
_mm_loadu_pd((double const*)(temp_c + ldc * 3));//R(c[0][3]) I(c[0][3]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][3]) I(c[0][3]) + //ymm9+=beta_imag*(-I(c[0][3])) beta_imag*R(c[0][3]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm9,ymm15,ymm2); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 4));//R(c[0][4]) I(c[0][4]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][4]) I(c[0][4]) + //ymm11+=beta_imag*(-I(c[0][4])) beta_imag*R(c[0][4]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm11,ymm15,ymm2); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 5));//R(c[0][5]) I(c[0][5]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][5]) I(c[0][5]) + //ymm13+=beta_imag*(-I(c[0][5])) beta_imag*R(c[0][5]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2); + } + + xmm5 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *)(temp_c), xmm5); + + xmm5 = _mm256_extractf128_pd(ymm5, 0); + _mm_storeu_pd((double *)(temp_c + ldc), xmm5); + + xmm5 = _mm256_extractf128_pd(ymm7, 0); + _mm_storeu_pd((double *)(temp_c + ldc*2), xmm5); + + xmm5 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)(temp_c + ldc*3), xmm5); + + xmm5 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(temp_c + ldc*4), xmm5); + + xmm5 = _mm256_extractf128_pd(ymm13, 0); + _mm_storeu_pd((double *)(temp_c + ldc*5), xmm5); + + } + + } + + //Fringe case for N + if(n_remainder>=4) + { + dcomplex* temp_b = b + (n - n_remainder)*ldb; + dcomplex* temp_a = a; + dcomplex* temp_c = c + (n - n_remainder)*ldc; + + //Main loop for M + for(dim_t i = 0;i < (m-Z_MR+1);i=i+Z_MR) + { + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_real, alpha_imag + where alpha_real and/or alpha_imag is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_b where + computing all Z_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
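              d. Compared with the Z_NR = 6 main loop above, this path handles four
                 remaining columns at a time, so only the accumulator pairs
                 ymm3/ymm4 .. ymm9/ymm10 are needed (two ymm registers per column of
                 the 4x4 block of C); ymm13/ymm14 serve purely as temporaries for the
                 alpha scaling here and are not re-zeroed afterwards.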
+ */ + + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + //R(a[2][0]) I(a[2][0]) R(a[3][0]) I(a[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_a + 2)); + + ymm13 = ymm0; + ymm14 = ymm1; + SCALE_ALPHA_REAL_M_LOOP(ymm0,ymm1,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm15,ymm2,alpha_imag); + + /* + The result after scaling with alpha_real and/or alpha_imag is as follows: + For ymm0 : + R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) + I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] + R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) + I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) + + For ymm1 : + R(a[2][0]) = alpha_real*R(a[2][0])-alpha_imag*I(a[2][0]) + I(a[2][0]) = alpha_real*I(a[2][0])+alpha_imag*R[2][0] + R(a[3][0]) = alpha_real*R(a[3][0])-alpha_imag*I(a[3][0]) + I(a[3][0]) = alpha_real*I(a[3][0])+alpha_imag*(R[3][0]) + */ + + //Calculating using real part of complex number in B matrix + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)); + FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)); + FMA_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm2,(double const *)(temp_b+ldb*2)); + FMA_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm2,(double const *)(temp_b+ldb*3)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 and ymm1 in accordance to the requirement + NEG_PERM_M_LOOP(ymm0,ymm1,ymm2); + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)+1); + FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)+1); + FMA_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm2,(double const *)(temp_b+ldb*2)+1); + FMA_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm2,(double const *)(temp_b+ldb*3)+1); + + /* + a. Perform beta*C using temp_c, beta_real, + where beta_real is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. 
+ */ + if(beta_real != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) + // beta_real*R(c[1][0]) beta_real*I(c[1][0]) + //ymm4+=beta_real*R(c[2][0]) beta_real*I(c[2][0]) + // beta_real*R(c[3][0]) beta_real*I(c[3][0]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); + //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) + // beta_real*R(c[1][1]) beta_real*I(c[1][1]) + //ymm6+=beta_real*R(c[2][1]) beta_real*I(c[2][1]) + // beta_real*R(c[3][1]) beta_real*I(c[3][1]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15); + + //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); + //R(c[2][2]) I(c[2][2]) R(c[3][2]) I(c[3][2]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*2 + 2)); + //ymm7+=beta_real*R(c[0][2]) beta_real*I(c[0][2]) + // beta_real*R(c[1][2]) beta_real*I(c[1][2]) + //ymm8+=beta_real*R(c[2][2]) beta_real*I(c[2][2]) + // beta_real*R(c[3][2]) beta_real*I(c[3][2]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm15); + + //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); + //R(c[2][3]) I(c[2][3]) R(c[3][3]) I(c[3][3]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*3 + 2)); + //ymm9+=beta_real*R(c[0][3]) beta_real*I(c[0][3]) + // beta_real*R(c[1][3]) beta_real*I(c[1][3]) + //ymm10+=beta_real*R(c[2][3]) beta_real*I(c[2][3]) + // beta_real*R(c[3][3]) beta_real*I(c[3][3]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm15); + } + /* + a. Perform beta*C using temp_c, beta_imag, + where beta_imag is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. 
+ */ + + if(beta_imag != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) + // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) + //ymm4+=beta_imag*(-I(c[2][0])) beta_imag*R(c[2][0]) + // beta_imag*(-I(c[3][0])) beta_imag*R(c[3][0]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15,ymm2); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); + //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) + // beta_imag*(-I(c[1][1])) beta_imag*R(c[1][1]) + //ymm6+=beta_imag*(-I(c[2][1])) beta_imag*R(c[2][1]) + // beta_imag*(-I(c[3][1])) beta_imag*R(c[3][1]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15,ymm2); + + //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); + //R(c[2][2]) I(c[2][2]) R(c[3][2]) I(c[3][2]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*2 + 2)); + //ymm7+=beta_imag*(-I(c[0][2])) beta_imag*R(c[0][2]) + // beta_imag*(-I(c[1][2])) beta_imag*R(c[1][2]) + //ymm8+=beta_imag*(-I(c[2][2])) beta_imag*R(c[2][2]) + // beta_imag*(-I(c[3][2])) beta_imag*R(c[3][2]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm15,ymm2); + + //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); + //R(c[2][3]) I(c[2][3]) R(c[3][3]) I(c[3][3]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*3 + 2)); + //ymm9+=beta_imag*(-I(c[0][3])) beta_imag*R(c[0][3]) + // beta_imag*(-I(c[1][3])) beta_imag*R(c[1][3]) + //ymm10+=beta_imag*(-I(c[2][3])) beta_imag*R(c[2][3]) + // beta_imag*(-I(c[3][3])) beta_imag*R(c[3][3]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm15,ymm2); + } + /* + The scaling has been done sequentially as follows: + - If alpha_real is not 0, it is used for scaling A + - If alpha_imag is not 0, it is used for scaling A using permutation + and selective negation, after loading + - If beta_real is not 0, is is used for scaling C + - If beta_imag is not 0, it is used for scaling C using permutation + and selective negation, after loading + + The results are accumalated in accordance to the non zero scalar values, + and similar approach is followed in fringe cases + */ + + _mm256_storeu_pd((double *)(temp_c), ymm3); + _mm256_storeu_pd((double *)(temp_c + 2), ymm4); + + _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); + _mm256_storeu_pd((double *)(temp_c + ldc + 2), ymm6); + + _mm256_storeu_pd((double *)(temp_c + ldc*2), ymm7); + _mm256_storeu_pd((double *)(temp_c + ldc*2 + 2), ymm8); + + _mm256_storeu_pd((double *)(temp_c + ldc*3), ymm9); + _mm256_storeu_pd((double *)(temp_c + ldc*3 + 2), ymm10); + + temp_c+=Z_MR; + temp_a+=Z_MR; + } + + // Fringe cases for M + dim_t m_rem=m_remainder; + if(m_rem>=2) + { + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + + + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + + ymm13 = ymm0; + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); + /* + The result after scaling with alpha_real and/or alpha_imag is as follows: + For ymm0 
: + R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) + I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] + R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) + I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) + */ + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); + //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) + // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)); + //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) + // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); + //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) + // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)+1); + //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) + // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)+1); + + + if(beta_real != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) + // beta_real*R(c[1][0]) beta_real*I(c[1][0]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) + // beta_real*R(c[1][1]) beta_real*I(c[1][1]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); + + //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); + //ymm7+=beta_real*R(c[0][2]) beta_real*I(c[0][2]) + // beta_real*R(c[1][2]) beta_real*I(c[1][2]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm7,ymm15); + + //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); + //ymm9+=beta_real*R(c[0][3]) beta_real*I(c[0][3]) + // beta_real*R(c[1][3]) beta_real*I(c[1][3]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm9,ymm15); + } + + if(beta_imag != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) + // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) + // beta_imag*(-I(c[1][1])) beta_imag*R(c[1][1]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); + + //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) + ymm0 = 
_mm256_loadu_pd((double const *)(temp_c + ldc*2)); + //ymm7+=beta_imag*(-I(c[0][2])) beta_imag*R(c[0][2]) + // beta_imag*(-I(c[1][2])) beta_imag*R(c[1][2]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm7,ymm15,ymm2); + + //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); + //ymm9+=beta_imag*(-I(c[0][3])) beta_imag*R(c[0][3]) + // beta_imag*(-I(c[1][3])) beta_imag*R(c[1][3]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm9,ymm15,ymm2); + } + + /* + The scaling has been done sequentially as follows: + - If alpha_real is not 0, it is used for scaling A + - If alpha_imag is not 0, it is used for scaling A using permutation + and selective negation, after loading + - If beta_real is not 0, is is used for scaling C + - If beta_imag is not 0, it is used for scaling C using permutation + and selective negation, after loading + + The results are accumalated in accordance to the non zero scalar values, + and similar approach is followed in fringe cases + */ + + _mm256_storeu_pd((double *)(temp_c), ymm3); + _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); + _mm256_storeu_pd((double *)(temp_c + ldc*2), ymm7); + _mm256_storeu_pd((double *)(temp_c + ldc*3), ymm9); + + temp_c+=2; + temp_a+=2; + + m_rem -= 2; + } + + if(m_rem==1) + { + + xmm5 = _mm_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + + xmm5 = _mm_loadu_pd((double const*)(temp_a));//R(a[0][0]) I(a[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(a[0][0]) I(a[0][0]) + + ymm13 = ymm0; + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); + //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) + // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)); + //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) + // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); + //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) + // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)+1); + //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) + // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)+1); + + if(beta_real != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); + + xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) + 
SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) + //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 2));//R(c[0][2]) I(c[0][2]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][2]) I(c[0][2]) + //ymm7+=beta_real*R(c[0][2]) beta_real*I(c[0][2]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm7,ymm15); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 3));//R(c[0][3]) I(c[0][3]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][3]) I(c[0][3]) + //ymm9+=beta_real*R(c[0][3]) beta_real*I(c[0][3]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm9,ymm15); + } + + if(beta_imag != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); + + xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) + //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 2));//R(c[0][2]) I(c[0][2]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][2]) I(c[0][2]) + //ymm7+=beta_imag*(-I(c[0][2])) beta_imag*R(c[0][2]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm7,ymm15,ymm2); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 3));//R(c[0][3]) I(c[0][3]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][3]) I(c[0][3]) + //ymm9+=beta_imag*(-I(c[0][3])) beta_imag*R(c[0][3]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm9,ymm15,ymm2); + } + + xmm5 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *)(temp_c), xmm5); + + xmm5 = _mm256_extractf128_pd(ymm5, 0); + _mm_storeu_pd((double *)(temp_c + ldc), xmm5); + + xmm5 = _mm256_extractf128_pd(ymm7, 0); + _mm_storeu_pd((double *)(temp_c + ldc*2), xmm5); + + xmm5 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)(temp_c + ldc*3), xmm5); + + } + n_remainder -= 4; + + } + if(n_remainder>=2) + { + dcomplex* temp_b = b + (n - n_remainder)*ldb; + dcomplex* temp_a = a; + dcomplex* temp_c = c + (n - n_remainder)*ldc; + for(dim_t i = 0;i < (m-Z_MR+1);i=i+Z_MR) + { + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_real, alpha_imag + where alpha_real and/or alpha_imag is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_b where + computing all Z_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
+ */ + + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + //R(a[2][0]) I(a[2][0]) R(a[3][0]) I(a[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_a + 2)); + + ymm13 = ymm0; + ymm14 = ymm1; + SCALE_ALPHA_REAL_M_LOOP(ymm0,ymm1,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm15,ymm2,alpha_imag); + + /* + The result after scaling with alpha_real and/or alpha_imag is as follows: + For ymm0 : + R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) + I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] + R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) + I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) + + For ymm1 : + R(a[2][0]) = alpha_real*R(a[2][0])-alpha_imag*I(a[2][0]) + I(a[2][0]) = alpha_real*I(a[2][0])+alpha_imag*R[2][0] + R(a[3][0]) = alpha_real*R(a[3][0])-alpha_imag*I(a[3][0]) + I(a[3][0]) = alpha_real*I(a[3][0])+alpha_imag*(R[3][0]) + */ + + //Calculating using real part of complex number in B matrix + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)); + FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 and ymm1 in accordance to the requirement + NEG_PERM_M_LOOP(ymm0,ymm1,ymm2); + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)+1); + FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)+1); + + /* + a. Perform beta*C using temp_c, beta_real, + where beta_real is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. + */ + if(beta_real != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) + // beta_real*R(c[1][0]) beta_real*I(c[1][0]) + //ymm4+=beta_real*R(c[2][0]) beta_real*I(c[2][0]) + // beta_real*R(c[3][0]) beta_real*I(c[3][0]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); + //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) + // beta_real*R(c[1][1]) beta_real*I(c[1][1]) + //ymm6+=beta_real*R(c[2][1]) beta_real*I(c[2][1]) + // beta_real*R(c[3][1]) beta_real*I(c[3][1]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15); + } + + /* + a. Perform beta*C using temp_c, beta_imag, + where beta_imag is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. 
+ */ + + if(beta_imag != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) + // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) + //ymm4+=beta_imag*(-I(c[2][0])) beta_imag*R(c[2][0]) + // beta_imag*(-I(c[3][0])) beta_imag*R(c[3][0]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15,ymm2); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); + //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) + // beta_imag*(-I(c[1][1])) beta_imag*R(c[1][1]) + //ymm6+=beta_imag*(-I(c[2][1])) beta_imag*R(c[2][1]) + // beta_imag*(-I(c[3][1])) beta_imag*R(c[3][1]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15,ymm2); + } + /* + The scaling has been done sequentially as follows: + - If alpha_real is not 0, it is used for scaling A + - If alpha_imag is not 0, it is used for scaling A using permutation + and selective negation, after loading + - If beta_real is not 0, is is used for scaling C + - If beta_imag is not 0, it is used for scaling C using permutation + and selective negation, after loading + + The results are accumalated in accordance to the non zero scalar values, + and similar approach is followed in fringe cases + */ + + _mm256_storeu_pd((double *)(temp_c), ymm3); + _mm256_storeu_pd((double *)(temp_c + 2), ymm4); + + _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); + _mm256_storeu_pd((double *)(temp_c + ldc + 2), ymm6); + + temp_c+=Z_MR; + temp_a+=Z_MR; + } + + dim_t m_rem=m_remainder; + if(m_rem>=2) + { + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + + + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + + ymm13 = ymm0; + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); + /* + The result after scaling with alpha_real and/or alpha_imag is as follows: + For ymm0 : + R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) + I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] + R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) + I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) + */ + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); + + + if(beta_real != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const 
*)(temp_c)); + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) + // beta_real*R(c[1][0]) beta_real*I(c[1][0]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) + // beta_real*R(c[1][1]) beta_real*I(c[1][1]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); + } + + if(beta_imag != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) + // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) + // beta_imag*(-I(c[1][1])) beta_imag*R(c[1][1]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); + } + + /* + The scaling has been done sequentially as follows: + - If alpha_real is not 0, it is used for scaling A + - If alpha_imag is not 0, it is used for scaling A using permutation + and selective negation, after loading + - If beta_real is not 0, is is used for scaling C + - If beta_imag is not 0, it is used for scaling C using permutation + and selective negation, after loading + + The results are accumalated in accordance to the non zero scalar values, + and similar approach is followed in fringe cases + */ + + _mm256_storeu_pd((double *)(temp_c), ymm3); + _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); + + temp_c+=2; + temp_a+=2; + + m_rem -= 2; + } + + if(m_rem==1) + { + + xmm5 = _mm_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + + xmm5 = _mm_loadu_pd((double const*)(temp_a));//R(a[0][0]) I(a[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(a[0][0]) I(a[0][0]) + + ymm13 = ymm0; + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); + + if(beta_real != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); + + xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) + //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); + } + + if(beta_imag != 0.0) + { + 
ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); + + xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) + //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); + } + + xmm5 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *)(temp_c), xmm5); + + xmm5 = _mm256_extractf128_pd(ymm5, 0); + _mm_storeu_pd((double *)(temp_c + ldc), xmm5); + + } + n_remainder -= 2; + } + if(n_remainder==1) + { + dcomplex* temp_b = b + (n - n_remainder)*ldb; + dcomplex* temp_a = a; + dcomplex* temp_c = c + (n - n_remainder)*ldc; + + // Main loop for M + for(dim_t i = 0;i < (m-Z_MR+1);i=i+Z_MR) + { + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_real, aplha_vali + where alpha_real and/or alpha_imag is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_b where + computing all Z_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + //R(a[2][0]) I(a[2][0]) R(a[3][0]) I(a[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_a + 2)); + + ymm13 = ymm0; + ymm14 = ymm1; + SCALE_ALPHA_REAL_M_LOOP(ymm0,ymm1,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm15,ymm2,alpha_imag); + + /* + The result after scaling with alpha_real and/or alpha_imag is as follows: + For ymm0 : + R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) + I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] + R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) + I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) + + For ymm1 : + R(a[2][0]) = alpha_real*R(a[2][0])-alpha_imag*I(a[2][0]) + I(a[2][0]) = alpha_real*I(a[2][0])+alpha_imag*R[2][0] + R(a[3][0]) = alpha_real*R(a[3][0])-alpha_imag*I(a[3][0]) + I(a[3][0]) = alpha_real*I(a[3][0])+alpha_imag*(R[3][0]) + */ + + //Calculating using real part of complex number in B matrix + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 and ymm1 in accordance to the requirement + NEG_PERM_M_LOOP(ymm0,ymm1,ymm2); + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)+1); + + /* + a. Perform beta*C using temp_c, beta_real, + where beta_real is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. 
+ */ + if(beta_real != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) + // beta_real*R(c[1][0]) beta_real*I(c[1][0]) + //ymm4+=beta_real*R(c[2][0]) beta_real*I(c[2][0]) + // beta_real*R(c[3][0]) beta_real*I(c[3][0]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15); + } + + /* + a. Perform beta*C using temp_c, beta_imag, + where beta_imag is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. + */ + + if(beta_imag != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); + + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) + // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) + //ymm4+=beta_imag*(-I(c[2][0])) beta_imag*R(c[2][0]) + // beta_imag*(-I(c[3][0])) beta_imag*R(c[3][0]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15,ymm2); + } + /* + The scaling has been done sequentially as follows: + - If alpha_real is not 0, it is used for scaling A + - If alpha_imag is not 0, it is used for scaling A using permutation + and selective negation, after loading + - If beta_real is not 0, is is used for scaling C + - If beta_imag is not 0, it is used for scaling C using permutation + and selective negation, after loading + + The results are accumalated in accordance to the non zero scalar values, + and similar approach is followed in fringe cases + */ + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + _mm256_storeu_pd((double *)(temp_c), ymm3); + //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) + _mm256_storeu_pd((double *)(temp_c + 2), ymm4); + + temp_c+=Z_MR; + temp_a+=Z_MR; + } + + // Fringe cases for M + dim_t m_rem=m_remainder; + if(m_rem>=2) + { + ymm3 = _mm256_setzero_pd(); + + + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + + ymm13 = ymm0; + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); + + /* + The result after scaling with alpha_real and/or alpha_imag is as follows: + For ymm0 : + R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) + I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] + R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) + I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) + */ + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + + + if(beta_real != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); + + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) + // 
beta_real*R(c[1][0]) beta_real*I(c[1][0]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); + } + + if(beta_imag != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); + + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) + // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); + } + /* + The scaling has been done sequentially as follows: + - If alpha_real is not 0, it is used for scaling A + - If alpha_imag is not 0, it is used for scaling A using permutation + and selective negation, after loading + - If beta_real is not 0, is is used for scaling C + - If beta_imag is not 0, it is used for scaling C using permutation + and selective negation, after loading + + The results are accumalated in accordance to the non zero scalar values, + and similar approach is followed in fringe cases + */ + + _mm256_storeu_pd((double *)(temp_c), ymm3); + + temp_c+=2; + temp_a+=2; + + m_rem -= 2; + } + + if(m_rem==1) + { + + xmm5 = _mm_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + + xmm5 = _mm_loadu_pd((double const*)(temp_a)); + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0); + + ymm13 = ymm0; + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + + if(beta_real != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); + + xmm5 = _mm_loadu_pd((double const*)(temp_c)); + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0); + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); + } + + if(beta_imag != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); + + xmm5 = _mm_loadu_pd((double const*)(temp_c)); + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0); + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); + } + + xmm5 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *)(temp_c), xmm5); + + } + + } + +} \ No newline at end of file diff --git a/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16.c b/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16.c index 96bc927499..c309c8c0cd 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16.c +++ b/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16.c @@ -3,7 +3,7 @@ An object-based framework for developing high-performance BLAS-like libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2022 , Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -516,7 +516,8 @@ void bli_sgemmsup_rd_zen_asm_1x16 je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case label(.SROWSTORED) - vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovups(xmm4, mem(rcx)) jmp(.SDONE) // jump to end. 
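For reference, the per-element arithmetic that the SCALE_ALPHA_*, FMA_* and SCALE_BETA_* macros in the dcomplex kernel above implement can be written as a short scalar sketch. This is illustrative only and not part of the patch; the helper name is hypothetical, and the kernel applies alpha to the loaded A values up front, which is equivalent to scaling the accumulated product as done here.

    /* Hypothetical scalar model of one C(i,j) update in the kernel above:
       c = alpha*(sum_k a[i][k]*b[k][j]) + beta*c, with alpha and beta
       handled through their real and imaginary parts, as the
       SCALE_*_REAL / SCALE_*_IMAG macro pairs do. */
    static void zscal_axpby_ref( double alpha_r, double alpha_i,
                                 double beta_r,  double beta_i,
                                 double sum_r,   double sum_i,  /* accumulated A*B element */
                                 double* c_r,    double* c_i )  /* C element, updated in place */
    {
        /* alpha * (A*B) */
        double t_r = alpha_r*sum_r - alpha_i*sum_i;
        double t_i = alpha_r*sum_i + alpha_i*sum_r;
        /* beta * C; the kernel skips the real or imaginary half when the
           corresponding scalar is zero. */
        double u_r = beta_r*(*c_r) - beta_i*(*c_i);
        double u_i = beta_r*(*c_i) + beta_i*(*c_r);
        *c_r = t_r + u_r;
        *c_i = t_i + u_i;
    }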
diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c index 347384aa65..7befbb69bb 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -113,17 +113,19 @@ void bli_sgemmsup_rv_zen_asm_5x16 begin_asm() vxorps(ymm4, ymm4, ymm4) - vxorps(ymm5, ymm5, ymm5) - vxorps(ymm6, ymm6, ymm6) - vxorps(ymm7, ymm7, ymm7) - vxorps(ymm8, ymm8, ymm8) - vxorps(ymm9, ymm9, ymm9) - vxorps(ymm10, ymm10, ymm10) - vxorps(ymm11, ymm11, ymm11) - vxorps(ymm12, ymm12, ymm12) - vxorps(ymm13, ymm13, ymm13) - vxorps(ymm14, ymm14, ymm14) - vxorps(ymm15, ymm15, ymm15) + vmovaps(ymm4, ymm5) + vmovaps(ymm4, ymm6) + vmovaps(ymm4, ymm7) + vmovaps(ymm4, ymm8) + vmovaps(ymm4, ymm9) + vmovaps(ymm4, ymm10) + vmovaps(ymm4, ymm11) + vmovaps(ymm4, ymm12) + vmovaps(ymm4, ymm13) + vmovaps(ymm4, ymm14) + vmovaps(ymm4, ymm15) + + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -694,6 +696,7 @@ void bli_sgemmsup_rv_zen_asm_5x16 vmovss(xmm14, mem(rdx, rax, 1)) label(.SDONE) + vzeroupper() end_asm( : // output operands (none) @@ -758,19 +761,20 @@ void bli_sgemmsup_rv_zen_asm_4x16 // ------------------------------------------------------------------------- begin_asm() - - vxorps(ymm4, ymm4, ymm4) - vxorps(ymm5, ymm5, ymm5) - vxorps(ymm6, ymm6, ymm6) - vxorps(ymm7, ymm7, ymm7) - vxorps(ymm8, ymm8, ymm8) - vxorps(ymm9, ymm9, ymm9) - vxorps(ymm10, ymm10, ymm10) - vxorps(ymm11, ymm11, ymm11) - vxorps(ymm12, ymm12, ymm12) - vxorps(ymm13, ymm13, ymm13) - vxorps(ymm14, ymm14, ymm14) - vxorps(ymm15, ymm15, ymm15) + + vxorps(ymm4, ymm4, ymm4) + vmovaps(ymm4, ymm5) + vmovaps(ymm4, ymm6) + vmovaps(ymm4, ymm7) + vmovaps(ymm4, ymm8) + vmovaps(ymm4, ymm9) + vmovaps(ymm4, ymm10) + vmovaps(ymm4, ymm11) + vmovaps(ymm4, ymm12) + vmovaps(ymm4, ymm13) + vmovaps(ymm4, ymm14) + vmovaps(ymm4, ymm15) + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -822,14 +826,14 @@ void bli_sgemmsup_rv_zen_asm_4x16 prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 7*cs_c label(.SPOSTPFETCH) // done prefetching c - + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
- + label(.SLOOPKITER) // MAIN LOOP - + // ---------------------------------- iteration 0 vmovups(mem(rbx, 0*32), ymm0) vmovups(mem(rbx, 1*32), ymm1) @@ -1188,7 +1192,8 @@ void bli_sgemmsup_rv_zen_asm_4x16 vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) label(.SDONE) - + vzeroupper() + end_asm( : // output operands (none) : // input operands @@ -1252,19 +1257,20 @@ void bli_sgemmsup_rv_zen_asm_3x16 // ------------------------------------------------------------------------- begin_asm() - - vxorps(ymm4, ymm4, ymm4) - vxorps(ymm5, ymm5, ymm5) - vxorps(ymm6, ymm6, ymm6) - vxorps(ymm7, ymm7, ymm7) - vxorps(ymm8, ymm8, ymm8) - vxorps(ymm9, ymm9, ymm9) - vxorps(ymm10, ymm10, ymm10) - vxorps(ymm11, ymm11, ymm11) - vxorps(ymm12, ymm12, ymm12) - vxorps(ymm13, ymm13, ymm13) - vxorps(ymm14, ymm14, ymm14) - vxorps(ymm15, ymm15, ymm15) + + vxorps(ymm4, ymm4, ymm4) + vmovaps(ymm4, ymm5) + vmovaps(ymm4, ymm6) + vmovaps(ymm4, ymm7) + vmovaps(ymm4, ymm8) + vmovaps(ymm4, ymm9) + vmovaps(ymm4, ymm10) + vmovaps(ymm4, ymm11) + vmovaps(ymm4, ymm12) + vmovaps(ymm4, ymm13) + vmovaps(ymm4, ymm14) + vmovaps(ymm4, ymm15) + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -1746,6 +1752,7 @@ void bli_sgemmsup_rv_zen_asm_3x16 vmovss(xmm14, mem(rdx, rax, 1)) label(.SDONE) + vzeroupper() end_asm( : // output operands (none) @@ -6846,7 +6853,7 @@ void bli_sgemmsup_rv_zen_asm_6x2 // ---------------------------------- iteration 0 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -6867,7 +6874,7 @@ void bli_sgemmsup_rv_zen_asm_6x2 // ---------------------------------- iteration 1 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -6888,7 +6895,7 @@ void bli_sgemmsup_rv_zen_asm_6x2 // ---------------------------------- iteration 2 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -6909,7 +6916,7 @@ void bli_sgemmsup_rv_zen_asm_6x2 // ---------------------------------- iteration 3 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -6942,7 +6949,7 @@ void bli_sgemmsup_rv_zen_asm_6x2 label(.SLOOPKLEFT) // EDGE LOOP - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -7230,7 +7237,7 @@ void bli_sgemmsup_rv_zen_asm_5x2 label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -7249,7 +7256,7 @@ void bli_sgemmsup_rv_zen_asm_5x2 // ---------------------------------- iteration 1 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -7268,7 +7275,7 @@ void bli_sgemmsup_rv_zen_asm_5x2 // ---------------------------------- iteration 2 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -7287,7 +7294,7 @@ void bli_sgemmsup_rv_zen_asm_5x2 // ---------------------------------- iteration 3 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -7318,7 +7325,7 @@ void bli_sgemmsup_rv_zen_asm_5x2 label(.SLOOPKLEFT) // EDGE LOOP - vmovups(mem(rbx, 0*32), xmm0) + 
vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -7605,7 +7612,7 @@ void bli_sgemmsup_rv_zen_asm_4x2 label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7620,7 +7627,7 @@ void bli_sgemmsup_rv_zen_asm_4x2 vfmadd231ps(xmm0, xmm3, xmm10) // ---------------------------------- iteration 1 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7635,7 +7642,7 @@ void bli_sgemmsup_rv_zen_asm_4x2 vfmadd231ps(xmm0, xmm3, xmm10) // ---------------------------------- iteration 2 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7650,7 +7657,7 @@ void bli_sgemmsup_rv_zen_asm_4x2 vfmadd231ps(xmm0, xmm3, xmm10) // ---------------------------------- iteration 3 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7678,7 +7685,7 @@ void bli_sgemmsup_rv_zen_asm_4x2 label(.SLOOPKLEFT) // EDGE LOOP - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7933,7 +7940,7 @@ void bli_sgemmsup_rv_zen_asm_3x2 label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7946,7 +7953,7 @@ void bli_sgemmsup_rv_zen_asm_3x2 vfmadd231ps(xmm0, xmm2, xmm8) // ---------------------------------- iteration 1 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7959,7 +7966,7 @@ void bli_sgemmsup_rv_zen_asm_3x2 vfmadd231ps(xmm0, xmm2, xmm8) // ---------------------------------- iteration 2 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7972,7 +7979,7 @@ void bli_sgemmsup_rv_zen_asm_3x2 vfmadd231ps(xmm0, xmm2, xmm8) // ---------------------------------- iteration 3 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7998,7 +8005,7 @@ void bli_sgemmsup_rv_zen_asm_3x2 label(.SLOOPKLEFT) // EDGE LOOP - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8041,15 +8048,18 @@ void bli_sgemmsup_rv_zen_asm_3x2 label(.SROWSTORED) - vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx)) add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx)) add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm8) + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx)) jmp(.SDONE) // jump to end. 
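The recurring change in these Nx2 edge kernels is to load only the two valid floats with vmovsd and feed a register operand to vfmadd231ps, rather than letting the FMA's 16-byte memory operand read past the two elements that actually belong to the matrix. A minimal intrinsics sketch of the pattern, illustrative only and with a hypothetical helper name, is:

    #include <immintrin.h>

    /* Sketch of the fixed row-store epilogue for a 2-wide tile: read just
       c[0..1] (8 bytes), fuse with beta, and write 8 bytes back, so that
       c[2..3] -- which may lie outside the matrix -- are never touched. */
    static inline void srowstore_2( float* c, __m128 beta, __m128 acc )
    {
        __m128 c01 = _mm_castpd_ps( _mm_load_sd( (const double*)c ) ); /* c[0], c[1] */
        acc = _mm_fmadd_ps( c01, beta, acc );                          /* acc += beta*c */
        _mm_store_sd( (double*)c, _mm_castps_pd( acc ) );              /* store c[0], c[1] */
    }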
@@ -8235,7 +8245,7 @@ void bli_sgemmsup_rv_zen_asm_2x2 // ---------------------------------- iteration 0 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) vbroadcastss(mem(rax, r8, 1), ymm3) @@ -8244,7 +8254,7 @@ void bli_sgemmsup_rv_zen_asm_2x2 vfmadd231ps(xmm0, xmm3, xmm6) // ---------------------------------- iteration 1 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) vbroadcastss(mem(rax, r8, 1), ymm3) @@ -8253,7 +8263,7 @@ void bli_sgemmsup_rv_zen_asm_2x2 vfmadd231ps(xmm0, xmm3, xmm6) // ---------------------------------- iteration 2 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8263,7 +8273,7 @@ void bli_sgemmsup_rv_zen_asm_2x2 vfmadd231ps(xmm0, xmm3, xmm6) // ---------------------------------- iteration 3 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8284,7 +8294,7 @@ void bli_sgemmsup_rv_zen_asm_2x2 label(.SLOOPKLEFT) // EDGE LOOP - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8322,11 +8332,13 @@ void bli_sgemmsup_rv_zen_asm_2x2 label(.SROWSTORED) - vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx)) add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx)) jmp(.SDONE) // jump to end. @@ -8491,7 +8503,7 @@ void bli_sgemmsup_rv_zen_asm_1x2 label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8499,7 +8511,7 @@ void bli_sgemmsup_rv_zen_asm_1x2 vfmadd231ps(xmm0, xmm2, xmm4) // ---------------------------------- iteration 1 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8508,7 +8520,7 @@ void bli_sgemmsup_rv_zen_asm_1x2 // ---------------------------------- iteration 2 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8516,7 +8528,7 @@ void bli_sgemmsup_rv_zen_asm_1x2 vfmadd231ps(xmm0, xmm2, xmm4) // ---------------------------------- iteration 3 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8536,7 +8548,7 @@ void bli_sgemmsup_rv_zen_asm_1x2 label(.SLOOPKLEFT) // EDGE LOOP - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -8570,7 +8582,8 @@ void bli_sgemmsup_rv_zen_asm_1x2 label(.SROWSTORED) - vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx)) jmp(.SDONE) // jump to end. diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c index 41dbbd699e..d5e2135a66 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -1291,13 +1291,17 @@ void bli_sgemmsup_rv_zen_asm_6x8m vextractf128(imm(0x1), ymm0, xmm2) vpermilps(imm(0xe),xmm0,xmm5) vpermilps(imm(0xe),xmm2,xmm6) - vfmadd231ps(mem(rdx), xmm3, xmm0) - vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm2) + vmovq(mem(rdx),xmm4) + vmovq(mem(rdx, rsi, 4),xmm1) + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm1, xmm3, xmm2) vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) lea(mem(rdx, rsi, 1), rdx) - vfmadd231ps(mem(rdx), xmm3, xmm5) - vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm6) + vmovq(mem(rdx),xmm4) + vmovq(mem(rdx, rsi, 4),xmm1) + vfmadd231ps(xmm4, xmm3, xmm5) + vfmadd231ps(xmm1, xmm3, xmm6) vmovlpd(xmm5, mem(rdx)) // store ( gamma41..gamma51 ) vmovlpd(xmm6, mem(rdx, rsi, 4)) // store ( gamma45..gamma55 ) lea(mem(rdx, rsi, 1), rdx) @@ -1306,13 +1310,17 @@ void bli_sgemmsup_rv_zen_asm_6x8m vextractf128(imm(0x1), ymm0, xmm2) vpermilps(imm(0xe),xmm0,xmm5) vpermilps(imm(0xe),xmm2,xmm6) - vfmadd231ps(mem(rdx), xmm3, xmm0) - vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm2) + vmovq(mem(rdx),xmm4) + vmovq(mem(rdx, rsi, 4),xmm1) + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm1, xmm3, xmm2) vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma46..gamma56 ) lea(mem(rdx, rsi, 1), rdx) - vfmadd231ps(mem(rdx), xmm3, xmm5) - vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm6) + vmovq(mem(rdx),xmm4) + vmovq(mem(rdx, rsi, 4),xmm1) + vfmadd231ps(xmm4, xmm3, xmm5) + vfmadd231ps(xmm1, xmm3, xmm6) vmovlpd(xmm5, mem(rdx)) // store ( gamma43..gamma53 ) vmovlpd(xmm6, mem(rdx, rsi, 4)) // store ( gamma47..gamma57 ) @@ -1810,11 +1818,13 @@ void bli_sgemmsup_rv_zen_asm_6x4m lea(mem(rdx, rsi, 1), rdx) vunpckhps(xmm14, xmm12, xmm0) vpermilps(imm(0x4e), xmm0, xmm5) - vfmadd231ps(mem(rdx), xmm3, xmm0) + vmovq(mem(rdx),xmm4) + vfmadd231ps(xmm4, xmm3, xmm0) vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) lea(mem(rdx, rsi, 1), rdx) - vfmadd231ps(mem(rdx), xmm3, xmm5) + vmovq(mem(rdx),xmm4) + vfmadd231ps(xmm4, xmm3, xmm5) vmovlpd(xmm5, mem(rdx)) // store ( gamma43..gamma53 ) jmp(.SDONE) // jump to end. @@ -2231,22 +2241,28 @@ void bli_sgemmsup_rv_zen_asm_6x2m label(.SROWSTORED) - vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovlpd(xmm4, mem(rcx)) add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovlpd(xmm6, mem(rcx)) add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm8) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) vmovlpd(xmm8, mem(rcx)) add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm10) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) vmovlpd(xmm10, mem(rcx)) add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm12) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm12) vmovlpd(xmm12, mem(rcx)) add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm14) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm14) vmovlpd(xmm14, mem(rcx)) jmp(.SDONE) // jump to end. 
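Another change repeated across the sgemmsup_rv_zen kernels earlier in this patch is the vzeroupper added at each .SDONE label before end_asm(). A minimal sketch of the same state change using the matching intrinsic is shown below (illustrative only); clearing the upper halves of the YMM registers before returning avoids the SSE/AVX transition penalty if the caller continues in legacy SSE code.

    #include <immintrin.h>

    /* Illustrative epilogue: same architectural effect as the vzeroupper
       instruction added at the kernels' .SDONE labels. */
    static inline void kernel_epilogue( void )
    {
        _mm256_zeroupper();
    }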
diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index f2edd993ce..f16aa5cc98 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -138,6 +138,10 @@ GEMV_KER_PROT( double, d, gemv_zen_ref_c ) GEMV_KER_PROT( scomplex, c, gemv_zen_int_4x4 ) GEMV_KER_PROT( dcomplex, z, gemv_zen_int_4x4 ) +// her (intrinsics) +HER_KER_PROT( dcomplex, z, her_zen_int_var1 ) +HER_KER_PROT( dcomplex, z, her_zen_int_var2 ) + // -- level-3 sup -------------------------------------------------------------- // semmsup_rv @@ -285,18 +289,6 @@ err_t bli_zgemm_small_At cntl_t* cntl ); -// gemm square matrix size friendly implementation -err_t bli_gemm_sqp - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl - ); - void bli_dgemm_ref_k1_nn ( dim_t m, @@ -309,6 +301,18 @@ void bli_dgemm_ref_k1_nn double* c, const inc_t ldc ); +void bli_zgemm_ref_k1_nn + ( + dim_t m, + dim_t n, + dim_t k, + dcomplex* alpha, + dcomplex* a, const inc_t lda, + dcomplex* b, const inc_t ldb, + dcomplex* beta, + dcomplex* c, const inc_t ldc + ); + err_t bli_trsm_small ( side_t side, @@ -364,13 +368,37 @@ bool bli_cntx_syrksup_thresh_is_met_zen cntx_t* cntx ); -#ifdef BLIS_ENABLE_FAST_MATH -void bli_dnorm2fv_unb_var1 +/* + * Check if the TRSM small path should be taken for this + * input and threads combination + */ +bool bli_cntx_trsm_small_thresh_is_met_zen + ( + obj_t* a, + dim_t m, + dim_t n + ); + +void bli_dnorm2fv_unb_var1_avx2 ( dim_t n, double* x, inc_t incx, double* norm, cntx_t* cntx ); -#endif +void bli_dznorm2fv_unb_var1_avx2 + ( + dim_t n, + dcomplex* x, inc_t incx, + double* norm, + cntx_t* cntx + ); +void bli_zdscalv_zen_int10 + ( + conj_t conjalpha, + dim_t n, + double* restrict alpha, + dcomplex* restrict x, inc_t incx, + cntx_t* restrict cntx + ); \ No newline at end of file diff --git a/kernels/zen/util/bli_thresh_funcs_zen.c b/kernels/zen/util/bli_thresh_funcs_zen.c index 2786f00e43..82e4936fe1 100644 --- a/kernels/zen/util/bli_thresh_funcs_zen.c +++ b/kernels/zen/util/bli_thresh_funcs_zen.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen4/1/CMakeLists.txt b/kernels/zen4/1/CMakeLists.txt new file mode 100644 index 0000000000..7bd499efb6 --- /dev/null +++ b/kernels/zen4/1/CMakeLists.txt @@ -0,0 +1,6 @@ +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## + +target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_amaxv_zen_int_avx512.c + ) diff --git a/kernels/zen4/1/bli_amaxv_zen_int_avx512.c b/kernels/zen4/1/bli_amaxv_zen_int_avx512.c new file mode 100644 index 0000000000..9e32f955a8 --- /dev/null +++ b/kernels/zen4/1/bli_amaxv_zen_int_avx512.c @@ -0,0 +1,970 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" +typedef union +{ + __m512d v; + double d[8] __attribute__((aligned(64))); +} v8df_t; + +/* Union data structure to access AVX registers + One 512-bit AVX register holds 16 SP elements. */ +typedef union +{ + __m512 v; + float f[16] __attribute__((aligned(64))); +} v16sf_t; + +/* Union data structure to access AVX registers + One 256-bit AVX register holds 8 SP elements. */ +typedef union +{ + __m256 v; + float f[8] __attribute__((aligned(64))); +} v8sf_t; + +typedef union +{ + __m128 v; + float f[4]; +} v4sf_t; + +/* Union data structure to access AVX registers + One 256-bit AVX register holds 4 DP elements. */ +typedef union +{ + __m256d v; + double d[4] __attribute__((aligned(64))); +}v4df_t; + +typedef union +{ + __m128d v; + double d[2]; +}v2dd_t; + +/* Convert the nan to -ve numbers decrementing with + the times the function is called to ensure that + bigger numbers are assigned for nan which showed + up first.*/ +#define REMOVE_NAN_512S(reg_512) \ + { \ + /*Sign is -0.f in IEEE754 is just signbit set, all others 0*/ \ + __m512 sign_mask = _mm512_set1_ps( -0.0f ); \ + \ + /* Numbers other than NAN will become 0. */ \ + __m512 vec_mask = _mm512_mul_ps( reg_512, sign_mask ); \ + \ + /* Typecast mask into int type no clock cycle is taken just to + * convince compiler. */ \ + __m512i int_mask_vec = _mm512_castps_si512( vec_mask ); \ + /* Extract the signbits and put it in a 16bit mask register. */ \ + __mmask16 vec_mask16 = _mm512_movepi32_mask( int_mask_vec ); \ + \ + /* Swap NAN with -ve number. 
*/ \ + reg_512 = _mm512_mask_blend_ps( vec_mask16, _mm512_set1_ps( nan_repl ), reg_512 ); \ + nan_repl = nan_repl - 1; \ + } + +// return a mask which indicates either: +// - v1 > v2 +// - v1 is NaN and v2 is not +// assumes that idx(v1) > idx(v2) +// all "OQ" comparisons false if either operand NaN +#define CMP256( dt, v1, v2 ) \ + _mm256_or_p##dt( _mm256_cmp_p##dt( v1, v2, _CMP_GT_OQ ), /* v1 > v2 || */ \ + _mm256_andnot_p##dt( _mm256_cmp_p##dt( v2, v2, _CMP_UNORD_Q ), /* ( !isnan(v2) && */ \ + _mm256_cmp_p##dt( v1, v1, _CMP_UNORD_Q ) /* isnan(v1) ) */ \ + ) \ + ); + +// return a mask which indicates either: +// - v1 > v2 +// - v1 is NaN and v2 is not +// - v1 == v2 (maybe == NaN) and i1 < i2 +// all "OQ" comparisons false if either operand NaN +#define CMP128( dt, v1, v2, i1, i2 ) \ + _mm_or_p##dt( _mm_or_p##dt( _mm_cmp_p##dt( v1, v2, _CMP_GT_OQ ), /* ( v1 > v2 || */ \ + _mm_andnot_p##dt( _mm_cmp_p##dt( v2, v2, _CMP_UNORD_Q ), /* ( !isnan(v2) && */ \ + _mm_cmp_p##dt( v1, v1, _CMP_UNORD_Q ) /* isnan(v1) ) ) || */ \ + ) \ + ), \ + _mm_and_p##dt( _mm_or_p##dt( _mm_cmp_p##dt( v1, v2, _CMP_EQ_OQ ), /* ( ( v1 == v2 || */ \ + _mm_and_p##dt( _mm_cmp_p##dt( v1, v1, _CMP_UNORD_Q ), /* ( isnan(v1) && */ \ + _mm_cmp_p##dt( v2, v2, _CMP_UNORD_Q ) /* isnan(v2) ) ) && */ \ + ) \ + ), \ + _mm_cmp_p##dt( i1, i2, _CMP_LT_OQ ) /* i1 < i2 ) */ \ + ) \ + ); + +// ---------------------------------------------------------------------------- +void bli_samaxv_zen_int_avx512( + dim_t n, + float *restrict x, inc_t incx, + dim_t *restrict i_max, + cntx_t *restrict cntx) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3) + // *minus_one = -1 + float *minus_one = PASTEMAC(s, m1); // bli_sm1() + // *zero_i = 0 + dim_t *zero_i = PASTEMAC(i, 0); // bli_i0() + + // Used to replace NAN in registers. This value is decremented each time + // remove NAN is applied so as to keep the NAN value replacements unique. + float nan_repl = -1.0; + + float fndMaxVal; // Max value will be stored in this + dim_t fndInd; // Max value's index will be stored in this + // Iterator for loops to keep continuity throughout the loops + dim_t i; + + /* If the vector length is zero, return early. This directly emulates + the behavior of netlib BLAS's i?amax() routines. */ + if (bli_zero_dim1(n)) + { + /* Set i_max to zero if dimension is 0, no need to compute */ + // Copy zero_i, that is 0 to i_max (i_max = 0) + PASTEMAC(i, copys) // bli_icopys + (*zero_i, *i_max); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) + return; + } + + /* Initialize the index of the maximum absolute value to zero. */ + // Copy zero_i, that is 0 to fndInd (fndInd = 0) + PASTEMAC(i, copys) // bli_icopys + (*zero_i, fndInd); + + /* Initialize the maximum absolute value search candidate with + -1, which is guaranteed to be less than all values we will + compute. */ + // Copy minus_one to fndMaxVal real and imaginary. + PASTEMAC(s, copys) // bli_scopys + (*minus_one, fndMaxVal); + + // For non-unit strides, or very small vector lengths, compute with + // scalar code. + // n is less than the single vector length or non unit stride. 
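    /* Illustration only, not executed by the kernel: the vector path below relies on
       REMOVE_NAN_512S, whose scalar model (with nan_repl starting at -1.0f) is
           if ( isnan( absval ) ) { absval = nan_repl; nan_repl -= 1.0f; }
       After taking absolute values every finite element is >= 0, so a replaced NaN can
       never beat a finite element, and an earlier NaN (replaced by the larger value,
       e.g. -1 vs -2) beats a later one. A negative running maximum after the vector
       code therefore means every element it inspected was NaN, and the result is
       reported as NaN with the index of the first occurrence. */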
+ if (incx != 1 || n < 16) + { + for (i = 0; i < n; ++i) + { + // Call math.h fabsf to take absolute value of *(x +(i)*incx) + float absval = fabsf(*(x + (i)*incx)); + if (fndMaxVal < absval || (isnan(absval) && !isnan(fndMaxVal))) + { + // If max value is found, set the value and index + fndMaxVal = absval; + fndInd = i; + } + } + } + else + { + dim_t num_iter, num_remain; + dim_t num_vector_elements = 16; + /* Total Registers used is + * xmm0-xmm4 + * ymm5-ymm9 + * zmm10-zmm26 + * There are 6 free registers to use + */ + // zmm register 15x + v16sf_t x_vec_1, x_vec_2, x_vec_3, max_vec_1, max_vec_2, + max_vec_3, maxInd_vec_1, maxInd_vec_2, + maxInd_vec_3, index_vec_1, ind_vec_2, + ind_vec_3, inc_vec, mask, + abs_mask; + // ymm register 5x + v8sf_t max_vec_lo, max_vec_hi, + maxInd_vec_lo, maxInd_vec_hi, + mask_vec_lo; + // xmm register 5x + v4sf_t max_vec_lo_lo, max_vec_lo_hi, + maxInd_vec_lo_lo, maxInd_vec_lo_hi, + mask_vec_lo_lo; + // zmm register 1x + __m512i intMask; + // k register 3x + __mmask16 mask_vec_1, mask_vec_2, + mask_vec_3; + + // Number of iterations for main loop. + num_iter = n / num_vector_elements; + // Number of iterations remaining for residual non vector loop + num_remain = n % num_vector_elements; + // A number with signbit one and others 0 IEEE-754 + abs_mask.v = _mm512_set1_ps(-0.f); + // index_vector after loading max_vector with initial values. + index_vec_1.v = _mm512_setr_ps(16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, + 28, 29, 30, 31); + // Broadcast 16. This is to increment the vector easily + inc_vec.v = _mm512_set1_ps(16); + // Load 16 float values from memory + max_vec_1.v = _mm512_loadu_ps(x); + // max_vector = abs(max_vector) + max_vec_1.v = _mm512_andnot_ps(abs_mask.v, max_vec_1.v); + // Remove nan and replace with -ve values + REMOVE_NAN_512S(max_vec_1.v); + + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + // indexes for values present in max vector. + maxInd_vec_1.v = _mm512_setr_ps(0, 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15); + + dim_t i = 1; + for (; (i + 4) < num_iter; i += 5) + { + /* + Unrolled to process 5 at a time. It basically works + by taking a master max_vec_1 and a maxInd_vec_1 + holding indexes. Elements are taken from the RAM on a batch + of 5 (1 master max_vec_1 already exists to compare so + 6 elements). Now each 2 of them is compared with each other + and an intermediate result is obtained. This intermediate + result is again with each other and combined until we reach + one vector in max_vector and maxIndex_vector. + */ + + // Load the vector and subs NAN + // Load Value x values + x_vec_1.v = _mm512_loadu_ps(x); + // x_vec_1 = abs(x_vec_1) + x_vec_1.v = _mm512_andnot_ps(abs_mask.v, x_vec_1.v); + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + // Remove nan and replace with -ve values + REMOVE_NAN_512S(x_vec_1.v); + + // Mask Generation of 1st(can be previous max) and 2nd element + // mask = max_vector - x_vec_1 + mask.v = _mm512_sub_ps(max_vec_1.v, x_vec_1.v); + // Type cast mask from IEEE754 (float) to integer type + // This operation will not need a new register, its just to convince + // the compiler. But its accounted as seperate register in the + // above calculations + intMask = _mm512_castps_si512(mask.v); + // Extract the signbit and build the mask. 
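    /* Why subtraction works as the comparator here: with NaNs already replaced,
       every lane holds an ordinary float, so the sign bit of (max_vec - x_vec) is
       set exactly in the lanes where x_vec is larger. Equal lanes give +0.0 (sign
       bit clear), so ties keep the earlier index, matching the scalar loop's
       strict '<' test. */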
+ mask_vec_1 = _mm512_movepi32_mask(intMask); + + // Load 2 elements to 2nd max and x vector, set indexes + // Load Value x values + max_vec_2.v = _mm512_loadu_ps(x); + // max_vec_2 = abs(max_vec_2) + max_vec_2.v = _mm512_andnot_ps(abs_mask.v, max_vec_2.v); + // Remove nan and replace with -ve values + REMOVE_NAN_512S(max_vec_2.v); + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + // Increment the index vector to point to next indexes. + maxInd_vec_2.v = _mm512_add_ps(index_vec_1.v, inc_vec.v); + + // Load Value x values + x_vec_2.v = _mm512_loadu_ps(x); + // x_vec_2 = abs(x_vec_2) + x_vec_2.v = _mm512_andnot_ps(abs_mask.v, x_vec_2.v); + // Remove nan and replace with -ve values + REMOVE_NAN_512S(x_vec_2.v); + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + // Increment the index vector to point to next indexes. + ind_vec_2.v = _mm512_add_ps(maxInd_vec_2.v, inc_vec.v); + + // Mask generation for last loaded 2 elements into x and max vectors. + // mask = max_vec_2 - x_vec_2 + mask.v = _mm512_sub_ps(max_vec_2.v, x_vec_2.v); + // Type cast mask from IEEE754 (float) to integer type + // This operation will not need a new register, its just to convince + // the compiler. But its accounted as seperate register in the + // above calculations + intMask = _mm512_castps_si512(mask.v); + // Extract the signbit and build the mask. + mask_vec_2 = _mm512_movepi32_mask(intMask); + + // Load 2 more elements to 3rd max and x vector, set indexes + // Load Value x values + max_vec_3.v = _mm512_loadu_ps(x); + // max_vec_3 = abs(max_vec_3) + max_vec_3.v = _mm512_andnot_ps(abs_mask.v, max_vec_3.v); + // Remove nan and replace with -ve values + REMOVE_NAN_512S(max_vec_3.v); + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + // Increment the index vector to point to next indexes. + maxInd_vec_3.v = _mm512_add_ps(ind_vec_2.v, inc_vec.v); + // Load Value x values + x_vec_3.v = _mm512_loadu_ps(x); + // x_vec_3 = abs(x_vec_3) + x_vec_3.v = _mm512_andnot_ps(abs_mask.v, x_vec_3.v); + // Remove nan and replace with -ve values + REMOVE_NAN_512S(x_vec_3.v); + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + // Increment the index vector to point to next indexes. + ind_vec_3.v = _mm512_add_ps(maxInd_vec_3.v, inc_vec.v); + + // Mask generation for last 2 elements loaded into x and max vectors. + // mask = max_vec_3 - x_vec_3 + mask.v = _mm512_sub_ps(max_vec_3.v, x_vec_3.v); + // Type cast mask from IEEE754 (float) to integer type + // This operation will not need a new register, its just to convince + // the compiler. But its accounted as seperate register in the + // above calculations + intMask = _mm512_castps_si512(mask.v); + // Extract the signbit and build the mask. + mask_vec_3 = _mm512_movepi32_mask(intMask); + + // Blend max vector and index vector (3 pairs of elements needs to be blended). 
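    /* Reduction order for the five freshly loaded vectors plus the running maximum:
       (max_1,x_1), (max_2,x_2) and (max_3,x_3) are merged pairwise first, then
       max_2 with max_3, and finally that survivor with max_1. Every value blend is
       mirrored by an index blend with the same mask, so value and position stay
       paired throughout. */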
+ /* Take values from max_vector if corresponding bit in mask_vector is 0 + * otherwise take value from x_vector, this is accumulated maximum value + * from max_vector and x_vector to mask_vector */ + max_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, + max_vec_1.v, + x_vec_1.v); + /* Take values from max_vector if corresponding bit in mask_vector is 0 + * otherwise take value from x_vector, this is accumulated maximum value + * from max_vector and x_vector to mask_vector */ + max_vec_2.v = _mm512_mask_blend_ps(mask_vec_2, + max_vec_2.v, + x_vec_2.v); + /* Take values from max_vector if corresponding bit in mask_vector is 0 + * otherwise take value from x_vector, this is accumulated maximum value + * from max_vector and x_vector to mask_vector */ + max_vec_3.v = _mm512_mask_blend_ps(mask_vec_3, + max_vec_3.v, + x_vec_3.v); + /* Take values from maxIndex_vector if corresponding bit in mask_vector + * is 0 otherwise take value from index_vec_1, this is accumulated + * maximum value index from maxIndex_vector and index_vec_1 + * to maxIndex_vector */ + maxInd_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, + maxInd_vec_1.v, + index_vec_1.v); + /* Take values from maxIndex_vector if corresponding bit in mask_vector + * is 0 otherwise take value from index_vec_1, this is accumulated + * maximum value index from maxIndex_vector and index_vec_1 + * to maxIndex_vector */ + maxInd_vec_2.v = _mm512_mask_blend_ps(mask_vec_2, + maxInd_vec_2.v, + ind_vec_2.v); + /* Take values from maxIndex_vector if corresponding bit in mask_vector + * is 0 otherwise take value from index_vec_1, this is accumulated + * maximum value index from maxIndex_vector and index_vec_1 + * to maxIndex_vector */ + maxInd_vec_3.v = _mm512_mask_blend_ps(mask_vec_3, + maxInd_vec_3.v, + ind_vec_3.v); + + // Mask generation for blending max_vec_2 and max_vec_3 to max_vec_2. + // mask = max_vec_2 - max_vec_3 + mask.v = _mm512_sub_ps(max_vec_2.v, max_vec_3.v); + // Type cast mask from IEEE754 (float) to integer type + // This operation will not need a new register, its just to convince + // the compiler. But its accounted as seperate register in the + // above calculations + intMask = _mm512_castps_si512(mask.v); + // Extract the signbit and build the mask. + mask_vec_2 = _mm512_movepi32_mask(intMask); + + // Blend to obtain 1 vector each of max values and index. + /* Take values from max_vec_2 if corresponding bit in mask_vec_2 + * is 0 otherwise take value from max_vec_3, this is accumulated + * maximum value from max_vec_2 and max_vec_3 to mask_vec_2 */ + max_vec_2.v = _mm512_mask_blend_ps(mask_vec_2, + max_vec_2.v, + max_vec_3.v); + /* Take values from maxInd_vec_2 if corresponding bit in mask_vector + * is 0 otherwise take value from maxInd_vec_3, this is accumulated + * maximum value index from maxInd_vec_2 and maxInd_vec_3 + * to maxInd_vec_2 */ + maxInd_vec_2.v = _mm512_mask_blend_ps(mask_vec_2, + maxInd_vec_2.v, + maxInd_vec_3.v); + + // Mask generation for blending max_vec_1 and max_vec_2 into max_vec_1. + // mask = max_vec_1 - max_vec_2 + mask.v = _mm512_sub_ps(max_vec_1.v, max_vec_2.v); + // Type cast mask from IEEE754 (float) to integer type + // This operation will not need a new register, its just to convince + // the compiler. But its accounted as seperate register in the + // above calculations + intMask = _mm512_castps_si512(mask.v); + // Extract the signbit and build the mask. 
+ mask_vec_1 = _mm512_movepi32_mask(intMask); + + // Final blend to the master max_vec_1 and maxInd_vec_1 + /* Take values from max_vec_1 if corresponding bit in mask_vec_1 + * is 0 otherwise take value from max_vec_2, this is accumulated + * maximum value from max_vec_1 and max_vec_2 to mask_vec_1 */ + max_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, max_vec_1.v, max_vec_2.v); + /* Take values from maxInd_vec_1 if corresponding bit in mask_vector + * is 0 otherwise take value from maxInd_vec_2, this is accumulated + * maximum value index from maxInd_vec_1 and maxInd_vec_2 + * to maxInd_vec_1 */ + maxInd_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, + maxInd_vec_1.v, + maxInd_vec_2.v); + + // Increment the index vector to point to next indexes. + index_vec_1.v = _mm512_add_ps(ind_vec_3.v, inc_vec.v); + } + + for (; i < num_iter; i++) + { + /* + Take vector one by one, above code makes max_vec_1 + contain the first 16 elements, now with the max vector + as first 16 elements (abs), we need to load next 16 elements + into x_vec_1 (abs). Now with those we can safely removeNan + which will put -ve values as NAN. + + These -ve values of NAN decreases by 1 in each iteration, + this helps us find the first NAN value. + */ + // Load Value x values + x_vec_1.v = _mm512_loadu_ps(x); + // x_vec_1 = abs(x_vec_1) + x_vec_1.v = _mm512_andnot_ps(abs_mask.v, x_vec_1.v); + // Remove nan and replace with -ve values + REMOVE_NAN_512S(x_vec_1.v); + + // Mask Generation + // mask = max_vec_1 - x_vec_1 + mask.v = _mm512_sub_ps(max_vec_1.v, x_vec_1.v); + // Extract the signbit and build the mask. + mask_vec_1 = _mm512_movepi32_mask(_mm512_castps_si512(mask.v)); + /* Take values from max_vec_1 if corresponding bit in + * mask_vec_1 is 0 otherwise take value from x_vec_1, + * this is accumulated maximum value from max_vec_1 and + * x_vec_1 to mask_vec_1 */ + max_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, + max_vec_1.v, + x_vec_1.v); + /* Take values from maxInd_vec_1 if corresponding bit in + * mask_vector is 0 otherwise take value from index_vec_1, + * this is accumulated maximum value index from maxInd_vec_1 + * and index_vec_1 to maxInd_vec_1 */ + maxInd_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, + maxInd_vec_1.v, + index_vec_1.v); + + // Increment the index vector to point to next indexes. + index_vec_1.v = _mm512_add_ps(index_vec_1.v, inc_vec.v); + + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + } + + num_remain = (n - ((i)*16)); + + /* + Now take the max vector and produce the max value from + the max vector by slicing and comparing with itself, + until we are left with just one index position and max value. 
+ */ + // Split max to hi and lo + max_vec_hi.v = _mm512_extractf32x8_ps(max_vec_1.v, 1); + max_vec_lo.v = _mm512_extractf32x8_ps(max_vec_1.v, 0); + + // Split maxIndex to hi and lo + maxInd_vec_hi.v = _mm512_extractf32x8_ps(maxInd_vec_1.v, 1); + maxInd_vec_lo.v = _mm512_extractf32x8_ps(maxInd_vec_1.v, 0); + + // Compare max_vec_hi > max_vec_1 + // mask_vec_lo = max_vec_lo - max_vec_hi + mask_vec_lo.v = _mm256_sub_ps(max_vec_lo.v, max_vec_hi.v); + + /* Take values from max_vec_lo if corresponding bit in mask_vec_lo + * is 0 otherwise take value from max_vec_hi, this is accumulated + * maximum value from max_vec_lo and max_vec_hi to max_vec_lo */ + max_vec_lo.v = _mm256_blendv_ps(max_vec_lo.v, + max_vec_hi.v, + mask_vec_lo.v); + /* Take values from maxInd_vec_lo if corresponding bit + * in mask_vec_lo is 0 otherwise take value from maxInd_vec_hi, + * this is accumulated maximum value from maxInd_vec_lo and + * maxInd_vec_hi to maxInd_vec_lo */ + maxInd_vec_lo.v = _mm256_blendv_ps(maxInd_vec_lo.v, + maxInd_vec_hi.v, + mask_vec_lo.v); + + // Split max_lo to hi and lo + max_vec_lo_hi.v = _mm256_extractf128_ps(max_vec_lo.v, 1); + max_vec_lo_lo.v = _mm256_extractf128_ps(max_vec_lo.v, 0); + + // Split maxIndex_lo to hi and lo + maxInd_vec_lo_hi.v = _mm256_extractf128_ps(maxInd_vec_lo.v, 1); + maxInd_vec_lo_lo.v = _mm256_extractf128_ps(maxInd_vec_lo.v, 0); + + // mask_vec_lo_lo = max_vec_lo_lo - max_vec_lo_hi + mask_vec_lo_lo.v = _mm_sub_ps(max_vec_lo_lo.v, max_vec_lo_hi.v); + /* Take values from max_vec_lo_lo if corresponding bit in + * mask_vec_lo_lo is 0 otherwise take value from max_vec_lo_hi, + * this is accumulated maximum value from max_vec_lo_lo and + * max_vec_lo_hi to max_vec_lo_lo */ + max_vec_lo_lo.v = _mm_blendv_ps(max_vec_lo_lo.v, + max_vec_lo_hi.v, + mask_vec_lo_lo.v); + /* Take values from maxInd_vec_lo if corresponding bit + * in mask_vec_lo_lo is 0 otherwise take value from maxInd_vec_hi, + * this is accumulated maximum value from maxInd_vec_lo and + * maxInd_vec_hi to maxInd_vec_lo */ + maxInd_vec_lo_lo.v = _mm_blendv_ps(maxInd_vec_lo_lo.v, + maxInd_vec_lo_hi.v, + mask_vec_lo_lo.v); + + // Take 64 high bits of max_lo_lo and put it to 64 low bits, rest 1st value + /* Example max_vec_lo_lo is {a, b, x, y} + * After max_vec_lo_hi.v = _mm_permute_ps(max_vec_lo_lo.v, 14); + * max_vec_lo_hi is {x, y, a, a} (essentially folding the vector) + */ + max_vec_lo_hi.v = _mm_permute_ps(max_vec_lo_lo.v, 14); + // Fold the vector same as max_vector + maxInd_vec_lo_hi.v = _mm_permute_ps(maxInd_vec_lo_lo.v, 14); + + // mask_vec_lo_lo = max_vec_lo_lo - max_vec_lo_hi + mask_vec_lo_lo.v = _mm_sub_ps(max_vec_lo_lo.v, max_vec_lo_hi.v); + /* Take values from max_vec_lo_lo if corresponding bit in + * mask_vec_lo_lo is 0 otherwise take value from max_vec_lo_hi, + * this is accumulated maximum value from max_vec_lo_lo and + * max_vec_lo_hi to max_vec_lo_lo */ + max_vec_lo_lo.v = _mm_blendv_ps(max_vec_lo_lo.v, + max_vec_lo_hi.v, + mask_vec_lo_lo.v); + /* Take values from maxInd_vec_lo if corresponding bit + * in mask_vec_lo_lo is 0 otherwise take value from maxInd_vec_hi, + * this is accumulated maximum value from maxInd_vec_lo and + * maxInd_vec_hi to maxInd_vec_lo */ + maxInd_vec_lo_lo.v = _mm_blendv_ps(maxInd_vec_lo_lo.v, + maxInd_vec_lo_hi.v, + mask_vec_lo_lo.v); + + // Take max_vec_lo_lo.f[1] and put it to max_vec_lo_hi.f[0] + /* Example max_vec_lo_lo is {a, b, x, y} + * After max_vec_lo_hi.v = _mm_permute_ps(max_vec_lo_lo.v, 1); + * max_vec_lo_hi is {b, a, a, a} (essentially folding the vector) + */ 
+ max_vec_lo_hi.v = _mm_permute_ps(max_vec_lo_lo.v, 1); + // Do the same operation. + maxInd_vec_lo_hi.v = _mm_permute_ps(maxInd_vec_lo_lo.v, 1); + + // mask_vec_lo_lo = max_vec_lo_lo - max_vec_lo_hi + mask_vec_lo_lo.v = _mm_sub_ps(max_vec_lo_lo.v, max_vec_lo_hi.v); + /* Take values from max_vec_lo_lo if corresponding bit in + * mask_vec_lo_lo is 0 otherwise take value from max_vec_lo_hi, + * this is accumulated maximum value from max_vec_lo_lo and + * max_vec_lo_hi to max_vec_lo_lo */ + max_vec_lo_lo.v = _mm_blendv_ps(max_vec_lo_lo.v, + max_vec_lo_hi.v, + mask_vec_lo_lo.v); + /* Take values from maxInd_vec_lo if corresponding bit + * in mask_vec_lo_lo is 0 otherwise take value from maxInd_vec_hi, + * this is accumulated maximum value from maxInd_vec_lo and + * maxInd_vec_hi to maxInd_vec_lo */ + maxInd_vec_lo_lo.v = _mm_blendv_ps(maxInd_vec_lo_lo.v, + maxInd_vec_lo_hi.v, + mask_vec_lo_lo.v); + /* We have kept on folding and comparing until we got one single index + * and max value so that is the final answer so set it as the final + * answer.*/ + fndInd = maxInd_vec_lo_lo.f[0]; + fndMaxVal = max_vec_lo_lo.f[0]; + // Found value is < 0 means it was the max NAN which was accumulated. + if (fndMaxVal < 0) + { + // So just set it as NAN + fndMaxVal = NAN; + } + // Finish off the remaining values using normal instructions + for (dim_t i = n - num_remain; i < n; i++) + { + float absval = fabsf(*(x)); + if (fndMaxVal < absval || (isnan(absval) && !isnan(fndMaxVal))) + { + fndMaxVal = absval; + fndInd = i; + } + x += 1; + } + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // later, especially if BLIS is compiled with -mfpmath=sse). + _mm256_zeroupper(); + + /* Store final index to output variable. */ + *i_max = fndInd; + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) +} + +// ----------------------------------------------------------------------------- +/* Converts all the NAN to a negative number less than previously encountered NANs*/ +#define REMOVE_NAN_512D(reg_512) \ + { \ + __m512d sign_mask = _mm512_set1_pd( -0.0f ); \ + \ + /* Numbers other than NAN will become 0. */ \ + __m512d vec_mask = _mm512_mul_pd( reg_512, sign_mask ); \ + \ + /* Producing an 8-bit mask. */ \ + __m512i int_mask_vec = _mm512_castpd_si512( vec_mask ); \ + __mmask8 vec_mask8 = _mm512_movepi64_mask( int_mask_vec ); \ + \ + /* Replacing all the NAN with negative numbers. */ \ + reg_512 = _mm512_mask_blend_pd( vec_mask8, _mm512_set1_pd( nan_repl ), reg_512 ); \ + nan_repl = nan_repl - 1; \ + } + +//---------------------------------------------------------------------------------------------------- +void bli_damaxv_zen_int_avx512( + dim_t n, + double *restrict x, inc_t incx, + dim_t *restrict i_max, + cntx_t *restrict cntx) +{ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) + double *minus_one = PASTEMAC(d, m1); + + // Used to replace NAN in registers. This value is decremented each time + // remove NAN is applied so as to keep the NAN value replacements unique. + double nan_repl = -1.0; + + dim_t *zero_i = PASTEMAC(i, 0); + + double chi1_r; + //double chi1_i; + double abs_chi1; + double abs_chi1_max; + dim_t i_max_l; + dim_t i; + + /* If the vector length is zero, return early. This directly emulates + the behavior of netlib BLAS's i?amax() routines. 
*/ + if (bli_zero_dim1(n)) + { + PASTEMAC(i, copys) + (*zero_i, *i_max); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) + return; + } + + /* Initialize the index of the maximum absolute value to zero. */ + PASTEMAC(i, copys) + (*zero_i, i_max_l); + + /* Initialize the maximum absolute value search candidate with + -1, which is guaranteed to be less than all values we will + compute. */ + PASTEMAC(d, copys) + (*minus_one, abs_chi1_max); + + // For non-unit strides, or very small vector lengths, compute with + // scalar code. + if (incx != 1 || n < 8) + { + for (i = 0; i < n; ++i) + { + double *chi1 = x + (i)*incx; + + /* Get the real and imaginary components of chi1. */ + chi1_r = *chi1; + + /* Replace chi1_r and chi1_i with their absolute values. */ + chi1_r = fabs(chi1_r); + + /* Add the real and imaginary absolute values together. */ + abs_chi1 = chi1_r; + + /* If the absolute value of the current element exceeds that of + the previous largest, save it and its index. If NaN is + encountered, then treat it the same as if it were a valid + value that was smaller than any previously seen. This + behavior mimics that of LAPACK's i?amax(). */ + if (abs_chi1_max < abs_chi1 || (isnan(abs_chi1) && !isnan(abs_chi1_max))) + { + abs_chi1_max = abs_chi1; + i_max_l = i; + } + } + } + else + { + + dim_t iterations, n_left, vector_length = 8, unrollCount = 0; + + //mask bits + __mmask8 mask_got_01, mask_got_23; + + //YMM0 - YMM6 registers + v4df_t max_hi, max_lo, max_ind_hi, max_ind_lo, + mask_final, inter_result, inter_ind; + + //XMM0 to XMM4 registers + v2dd_t max_vec_hi, max_vec_lo, max_ind_hi_128, + max_ind_lo_128, mask_vec_lo; + + //ZMM0 to ZMM13 registers + v8df_t zmm0, zmm1, zmm2, zmm3, zmm4_Ind, + zmm5_Ind, zmm6_Ind, zmm7_Ind, max_01, + max_23, final_max, max_array, max_ind, inc_vec; + + //ZMM14 to ZMM16 registers + __m512d mask_01, mask_23, sign_mask; + + //Intermediate int mask values + __m512i int_mask_01, int_mask_23; + + // Initialize sign mask + sign_mask = _mm512_set1_pd(-0.f); + + //Initializing the indexes of the base case of max vector + zmm4_Ind.v = _mm512_set_pd(7, 6, 5, 4, 3, 2, 1, 0); + inc_vec.v = _mm512_set1_pd(8); //Vector for incrementing + + // Initializing the max array as vec [ 0 : 512 ] + max_array.v = _mm512_loadu_pd(x); + + // Taking the absolute value and removing the NAN + max_array.v = _mm512_andnot_pd(sign_mask, max_array.v); + REMOVE_NAN_512D(max_array.v); + + // Initializing the maximumum index + max_ind.v = _mm512_set_pd(7, 6, 5, 4, 3, 2, 1, 0); + x += vector_length; + + //Incrementing to make the vector + //to point to the next 8 elements + zmm4_Ind.v = _mm512_add_pd(zmm4_Ind.v, inc_vec.v); + + /* Loop unrolled by a factor of 4 + At the end of the loop max_array holds the largest element + in each corresponding vector index */ + for (unrollCount = 8; (unrollCount + 31) < n; unrollCount += 32) + { + // Taking 32 elements + // Taking only the absolute values of the registers + // Removing the NAN values and replacing it + // with negative numbers + zmm0.v = _mm512_loadu_pd(x); + zmm0.v = _mm512_andnot_pd(sign_mask, zmm0.v); + REMOVE_NAN_512D(zmm0.v); + x += vector_length; + + zmm1.v = _mm512_loadu_pd(x); + zmm5_Ind.v = _mm512_add_pd(zmm4_Ind.v, inc_vec.v); + zmm1.v = _mm512_andnot_pd(sign_mask, zmm1.v); + REMOVE_NAN_512D(zmm1.v); + x += vector_length; + + zmm2.v = _mm512_loadu_pd(x); + zmm6_Ind.v = _mm512_add_pd(zmm5_Ind.v, inc_vec.v); + zmm2.v = _mm512_andnot_pd(sign_mask, zmm2.v); + REMOVE_NAN_512D(zmm2.v); + x += vector_length; + + zmm3.v = _mm512_loadu_pd(x); + 
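    /* zmm4_Ind..zmm7_Ind shadow the four loads of this unrolled step, each advanced
       by 8 elements, so the masks that merge the value vectors can merge the
       matching element indices as well. */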
zmm7_Ind.v = _mm512_add_pd(zmm6_Ind.v, inc_vec.v); + zmm3.v = _mm512_andnot_pd(sign_mask, zmm3.v); + REMOVE_NAN_512D(zmm3.v); + x += vector_length; + + /*Using sub function to generating the mask + as a 512d type*/ + mask_01 = _mm512_sub_pd(zmm0.v, zmm1.v); + mask_23 = _mm512_sub_pd(zmm2.v, zmm3.v); + + //Converting the 512d mask to a 512i mask + int_mask_01 = _mm512_castpd_si512(mask_01); + int_mask_23 = _mm512_castpd_si512(mask_23); + + /*Converting the 512i mask + to mmask type to use the mask bits*/ + mask_got_01 = _mm512_movepi64_mask(int_mask_01); + mask_got_23 = _mm512_movepi64_mask(int_mask_23); + + //Storing the largest elements in index % 8 position for + //vector 1 and 2, and the index of the corresponding element + max_01.v = _mm512_mask_blend_pd(mask_got_01, zmm0.v, zmm1.v); + zmm5_Ind.v = _mm512_mask_blend_pd(mask_got_01, zmm4_Ind.v, zmm5_Ind.v); + + //Storing the largest elements in index % 8 position for + //vector 3 and 4, and the index of the corresponding element + max_23.v = _mm512_mask_blend_pd(mask_got_23, zmm2.v, zmm3.v); + zmm6_Ind.v = _mm512_mask_blend_pd(mask_got_23, zmm6_Ind.v, zmm7_Ind.v); + + //Generating mask for the intermediate max vector + mask_01 = _mm512_sub_pd(max_01.v, max_23.v); + int_mask_01 = _mm512_castpd_si512(mask_01); + mask_got_01 = _mm512_movepi64_mask(int_mask_01); + + /*Storing the largest elements in index % 8 position for + the intermediate max vectors, + and the index of the corresponding element*/ + final_max.v = _mm512_mask_blend_pd(mask_got_01, max_01.v, max_23.v); + zmm5_Ind.v = _mm512_mask_blend_pd(mask_got_01, zmm5_Ind.v, zmm6_Ind.v); + + //Generating the mask for final max vector and base max vector + mask_01 = _mm512_sub_pd(max_array.v, final_max.v); + int_mask_01 = _mm512_castpd_si512(mask_01); + mask_got_01 = _mm512_movepi64_mask(int_mask_01); + + // Result is the maximum of all index % 8 locations + max_array.v = _mm512_mask_blend_pd(mask_got_01, max_array.v, final_max.v); + max_ind.v = _mm512_mask_blend_pd(mask_got_01, max_ind.v, zmm5_Ind.v); + + // Incrementing the index to point to the next 8 locations + zmm4_Ind.v = _mm512_add_pd(zmm7_Ind.v, inc_vec.v); + } + + // Calculating the number of iterations left + iterations = (n - unrollCount) / vector_length; + n_left = (n - unrollCount) % vector_length; + + /* At the end of the loop max_array holds the largest element + in each corresponding vector index */ + for (dim_t i = 0; i < iterations; ++i) + { + // Taking 32 elements + // Taking only the absolute values of the registers + // Removing the NAN values and replacing it + // with negative numbers + zmm0.v = _mm512_loadu_pd(x); + zmm0.v = _mm512_abs_pd(zmm0.v); + REMOVE_NAN_512D(zmm0.v); + + //Generating mask for the intermediate max vector + mask_01 = _mm512_sub_pd(max_array.v, zmm0.v); + int_mask_01 = _mm512_castpd_si512(mask_01); + mask_got_01 = _mm512_movepi64_mask(int_mask_01); + + // Result is the maximum of all index % 8 locations + max_array.v = _mm512_mask_blend_pd(mask_got_01, max_array.v, zmm0.v); + + //Storing the index of the corresponding max array elemets + max_ind.v = _mm512_mask_blend_pd(mask_got_01, max_ind.v, zmm4_Ind.v); + + //Incrementing the vector the point to the next location + //Incrementing the vector indexes + x += vector_length; + zmm4_Ind.v = _mm512_add_pd(zmm4_Ind.v, inc_vec.v); + } + + //Breaking max array into vectors of length 4 + //Taking upper and lower halves + max_hi.v = _mm512_extractf64x4_pd(max_array.v, 1); + max_ind_hi.v = _mm512_extractf64x4_pd(max_ind.v, 1); + max_lo.v = 
_mm512_extractf64x4_pd(max_array.v, 0); + max_ind_lo.v = _mm512_extractf64x4_pd(max_ind.v, 0); + + //Generating the mask for blending + mask_final.v = _mm256_sub_pd(max_hi.v, max_lo.v); + + // Storing the max of max array index % 4 + inter_result.v = _mm256_blendv_pd(max_hi.v, max_lo.v, mask_final.v); + inter_ind.v = _mm256_blendv_pd(max_ind_hi.v, max_ind_lo.v, mask_final.v); + + //Breaking max array into vectors of length 2 + max_vec_lo.v = _mm256_extractf128_pd(inter_result.v, 0); + max_vec_hi.v = _mm256_extractf128_pd(inter_result.v, 1); + max_ind_hi_128.v = _mm256_extractf128_pd(inter_ind.v, 1); + max_ind_lo_128.v = _mm256_extractf128_pd(inter_ind.v, 0); + + //Generating the mask for blending + mask_vec_lo.v = _mm_sub_pd(max_vec_lo.v, max_vec_hi.v); + + // Storing the max of max array index % 2 + max_vec_lo.v = _mm_blendv_pd(max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v); + max_ind_lo_128.v = _mm_blendv_pd(max_ind_lo_128.v, max_ind_hi_128.v, mask_vec_lo.v); + + max_vec_hi.v = _mm_permute_pd(max_vec_lo.v, 1); + max_ind_hi_128.v = _mm_permute_pd(max_ind_lo_128.v, 1); + + //Performing work of CMP128 i.e generating mask + mask_vec_lo.v = _mm_sub_pd(max_vec_lo.v, max_vec_hi.v); + + //Finding the maximum element + max_vec_lo.v = _mm_blendv_pd(max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v); + max_ind_lo_128.v = _mm_blendv_pd(max_ind_lo_128.v, max_ind_hi_128.v, mask_vec_lo.v); + + abs_chi1_max = max_vec_lo.d[0]; + + //If the largest number is negative it is NAN + if (abs_chi1_max < 0) + abs_chi1_max = NAN; + + i_max_l = max_ind_lo_128.d[0]; + + for (i = n - n_left; i < n; i++) + { + double *chi1 = x; + + /* Get the real and imaginary components of chi1. */ + chi1_r = *chi1; + + /* Replace chi1_r and chi1_i with their absolute values. */ + abs_chi1 = fabs(chi1_r); + + /* If the absolute value of the current element exceeds that of + the previous largest, save it and its index. If NaN is + encountered, return the index of the first NaN. This + behavior mimics that of LAPACK's i?amax(). */ + if (abs_chi1_max < abs_chi1 || (isnan(abs_chi1) && !isnan(abs_chi1_max))) + { + abs_chi1_max = abs_chi1; + i_max_l = i; + } + + x += 1; + } + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // later, especially if BLIS is compiled with -mfpmath=sse). + _mm256_zeroupper(); + + // Return value + *i_max = i_max_l; + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) +} diff --git a/kernels/zen4/3/CMakeLists.txt b/kernels/zen4/3/CMakeLists.txt new file mode 100644 index 0000000000..381204ae68 --- /dev/null +++ b/kernels/zen4/3/CMakeLists.txt @@ -0,0 +1,7 @@ +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## + +target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmtrsm_l_zen_16x14.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmtrsm_u_zen_16x14.c + ) diff --git a/kernels/zen4/3/bli_gemmtrsm_l_zen_16x14.c b/kernels/zen4/3/bli_gemmtrsm_l_zen_16x14.c new file mode 100644 index 0000000000..3680c44b05 --- /dev/null +++ b/kernels/zen4/3/bli_gemmtrsm_l_zen_16x14.c @@ -0,0 +1,1669 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +#include "blis.h" +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +#define A_L1_PREFETCH_DIST 12 // in units of k iterations +#define B_L1_PREFETCH_DIST 12 // e.g. 4 k iterations ~= 56 cycles +#define TAIL_NITER 5 // in units of 4x unrolled k iterations + // e.g. 5 -> 4*5 k iterations ~= 280 cycles + +#define PREFETCH_A_L1(n, k) \ + PREFETCH(0, MEM(RAX, A_L1_PREFETCH_DIST*16*8 + (2*n+k)*64)) +#define PREFETCH_B_L1(n, k) \ + PREFETCH(0, MEM(RBX, B_L1_PREFETCH_DIST*14*8 + (2*n+k)*56)) + +#define LOOP_ALIGN ALIGN32 + +#define SUBITER(n) \ +\ + PREFETCH_A_L1(n, 0) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 0)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 1)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(4)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(5)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(6)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(7)) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 2)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 3)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(8) ) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(9) ) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(10)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(11)) \ + \ + PREFETCH_B_L1(n, 0) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 4)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 5)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(12)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(13)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(14)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(15)) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 6)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 7)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(16)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(17)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(18)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(19)) \ + \ + PREFETCH_A_L1(n, 1) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 8)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 9)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(20)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(21)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(22)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(23)) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 10)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 11)*8), ZMM(3)) 
\ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(24)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(25)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(26)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(27)) \ + \ + PREFETCH_B_L1(n, 1) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 12)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 13)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(28)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(29)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(30)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(31)) \ + \ + VMOVAPD(MEM(RAX,((n*2)+2)*8*8), ZMM(0)) \ + VMOVAPD(MEM(RAX,((n*2)+3)*8*8), ZMM(1)) + +#define UPDATE_C_COL_SCATTERED(R1,R2) \ +\ + KXNORW(K(0), K(0), K(1)) \ + KXNORW(K(0), K(0), K(2)) \ + VSCATTERQPD(ZMM(R1), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + VSCATTERQPD(ZMM(R2), MEM(R14,ZMM(2),1) MASK_K(2)) \ + ADD(RDI, RCX) \ + ADD(RDI, R14) \ + +/* scatter only first 6 elements of r1 and r2 */ +#define UPDATE_C_COL_SCATTERED_2x6(R1,R2) \ +\ + KXNORW(K(0), K(0), K(1)) \ + KXNORW(K(0), K(0), K(2)) \ + MOVQ(IMM(0b00111111), RAX) \ + KMOVQ(RAX, K(2)) \ + KMOVQ(RAX, K(1)) \ + VSCATTERQPD(ZMM(R1), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + VSCATTERQPD(ZMM(R2), MEM(R14,ZMM(2),1) MASK_K(2)) \ + ADD(RDI, RCX) \ + ADD(RDI, R14) \ + +/* +Transpose 8 zmm registers and store the output in the given 8 registers + Note: Requires offsetPointer for scatter instruction + and 512 bytes of free memory (rcx) for transpose. + Input : + R1 = [ 0, 1, 2, 3, 4, 5, 6, 7] + R2 = [ 8, 9, 10, 11, 12, 13, 14, 15] + R3 = [16, 17, 18, 19, 20, 21, 22, 23] + R4 = [24, 25, 26, 27, 28, 29, 30, 31] + R5 = [32, 33, 34, 35, 36, 37, 38, 39] + R6 = [40, 41, 42, 43, 44, 45, 46, 47] + R7 = [48, 49, 50, 51, 52, 53, 54, 55] + R18= [56, 57, 58, 59, 60, 61, 62, 63] + Output : + R1 = [0, 8, 16, 24, 32, 40, 48, 56] + R2 = [1, 9, 17, 25, 33, 41, 49, 57] + R3 = [2, 10, 18, 26, 34, 42, 50, 58] + R4 = [3, 11, 19, 27, 35, 43, 51, 59] + R5 = [4, 12, 20, 28, 36, 44, 52, 60] + R6 = [5, 13, 21, 29, 37, 45, 53, 61] + R7 = [6, 14, 22, 30, 38, 46, 54, 62] + R18= [7, 15, 23, 31, 39, 47, 55, 63] +*/ +#define TRANSPOSE_REGISTERS_8x8(R1, R2, R3, R4, R5, R6, R7, R18) \ +\ + MOV(R8, RCX) \ + MOV(VAR(cs_c), RSI) \ + MOV(R9, RDI) \ + LEA(MEM(RCX, RSI, 8), RDX) \ + MOV(VAR(offsetPtr), R13) \ + MOV(RDI, R12) \ + CMP(RSI, R12) \ + CMOVL(RSI, R12) \ + VPBROADCASTQ(R12, ZMM(0)) \ + VPMULLQ(MEM(R13 ), ZMM(0), ZMM(2)) \ + \ + KXNORW(K(0), K(0), K(1)) \ + KXNORW(K(0), K(0), K(2)) \ + KXNORW(K(0), K(0), K(3)) \ + KXNORW(K(0), K(0), K(4)) \ + VSCATTERQPD(ZMM(R1), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(1)) \ + VSCATTERQPD(ZMM(R2), MEM(RCX,ZMM(2),1) MASK_K(2)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(2)) \ + VSCATTERQPD(ZMM(R3), MEM(RCX,ZMM(2),1) MASK_K(3)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(3)) \ + VSCATTERQPD(ZMM(R4), MEM(RCX,ZMM(2),1) MASK_K(4)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(4)) \ + VSCATTERQPD(ZMM(R5), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(1)) \ + VSCATTERQPD(ZMM(R6), MEM(RCX,ZMM(2),1) MASK_K(2)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(2)) \ + VSCATTERQPD(ZMM(R7), MEM(RCX,ZMM(2),1) MASK_K(3)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(3)) \ + VSCATTERQPD(ZMM(R18), MEM(RCX,ZMM(2),1) MASK_K(4)) \ + \ + MOV(R8, RCX) \ + \ + VMOVUPD(MEM(RCX), ZMM(R1))\ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R2)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R3)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R4)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R5)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R6)) \ + ADD(R12, RCX) 
\ + VMOVUPD(MEM(RCX), ZMM(R7)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R18)) \ + +/* +Transpose six zmm registers and store the output in the given 8 registers + Note: Require offsetPointer for scatter instruction + and 512 bytes of free memory (rcx) for transpose. + Input : + R1 = [ 0, 1, 2, 3, 4, 5, 6, 7] + R2 = [ 8, 9, 10, 11, 12, 13, 14, 15] + R3 = [16, 17, 18, 19, 20, 21, 22, 23] + R4 = [24, 25, 26, 27, 28, 29, 30, 31] + R5 = [32, 33, 34, 35, 36, 37, 38, 39] + R6 = [40, 41, 42, 43, 44, 45, 46, 47] + Output : + R1 = [0, 8, 16, 24, 32, 40, -, -] + R2 = [1, 9, 17, 25, 33, 41, -, -] + R3 = [2, 10, 18, 26, 34, 42, -, -] + R4 = [3, 11, 19, 27, 35, 43, -, -] + R5 = [4, 12, 20, 28, 36, 44, -, -] + R6 = [5, 13, 21, 29, 37, 45, -, -] + R7 = [6, 14, 22, 30, 38, 46, -, -] + R18 = [7, 15, 23, 31, 39, 47, -, -] +*/ +#define TRANSPOSE_REGISTERS_6x8(R1, R2, R3, R4, R5, R6, R7, R18) \ +\ + MOV(R8, RCX) \ + MOV(VAR(cs_c), RSI) \ + MOV(R9, RDI) \ + LEA(MEM(RCX, RSI, 8), RDX) \ + MOV(VAR(offsetPtr), R13) \ + MOV(RDI, R12) \ + CMP(RSI, R12) \ + CMOVL(RSI, R12) \ + VPBROADCASTQ(R12, ZMM(0)) \ + VPMULLQ(MEM(R13 ), ZMM(0), ZMM(2)) \ + LEA(MEM(RCX, R12, 4), RCX) \ + LEA(MEM(RCX, R12, 1), RCX) \ + \ + KXNORW(K(0), K(0), K(1)) \ + KXNORW(K(0), K(0), K(2)) \ + KXNORW(K(0), K(0), K(3)) \ + KXNORW(K(0), K(0), K(4)) \ + VSCATTERQPD(ZMM(R1), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(1)) \ + VSCATTERQPD(ZMM(R2), MEM(RCX,ZMM(2),1) MASK_K(2)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(2)) \ + VSCATTERQPD(ZMM(R3), MEM(RCX,ZMM(2),1) MASK_K(3)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(3)) \ + VSCATTERQPD(ZMM(R4), MEM(RCX,ZMM(2),1) MASK_K(4)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(4)) \ + VSCATTERQPD(ZMM(R5), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(1)) \ + VSCATTERQPD(ZMM(R6), MEM(RCX,ZMM(2),1) MASK_K(2)) \ + \ + MOV(R8, RCX) \ + LEA(MEM(RCX, R12, 4), RCX) \ + LEA(MEM(RCX, R12, 1), RCX) \ + \ + VMOVUPD(MEM(RCX), ZMM(R1))\ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R2)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R3)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R4)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R5)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R6)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R7)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R18)) \ + +// Offsets for scatter/gather instructions +static int64_t offsets[16] __attribute__((aligned(64))) = + { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15}; + + +void bli_dgemmtrsm_l_zen_asm_16x14 +( + dim_t k_, + double* restrict alpha, + double* restrict a10, + double* restrict a11, + double* restrict b01, + double* restrict b11, + double* restrict c11, inc_t rs_c_, inc_t cs_c_, + auxinfo_t* restrict data, + cntx_t* restrict cntx +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_9); + const int64_t k = k_; + uint64_t rs_c = rs_c_ * 8; + const int64_t* offsetPtr = &offsets[0]; + uint64_t cs_c = cs_c_ * 8; + + BEGIN_ASM() + + //clear out registers + VXORPD(ZMM(4), ZMM(4), ZMM(4)) + VMOVAPD(ZMM(4), ZMM(5) ) + VMOVAPD(ZMM(4), ZMM(6) ) + VMOVAPD(ZMM(4), ZMM(7) ) + VMOVAPD(ZMM(4), ZMM(8) ) + VMOVAPD(ZMM(4), ZMM(9) ) + VMOVAPD(ZMM(4), ZMM(10)) + VMOVAPD(ZMM(4), ZMM(11)) + VMOVAPD(ZMM(4), ZMM(12)) + VMOVAPD(ZMM(4), ZMM(13)) + VMOVAPD(ZMM(4), ZMM(14)) + VMOVAPD(ZMM(4), ZMM(15)) + VMOVAPD(ZMM(4), ZMM(16)) + VMOVAPD(ZMM(4), ZMM(17)) + VMOVAPD(ZMM(4), ZMM(18)) + VMOVAPD(ZMM(4), ZMM(19)) + VMOVAPD(ZMM(4), ZMM(20)) + VMOVAPD(ZMM(4), ZMM(21)) + VMOVAPD(ZMM(4), ZMM(22)) + 
VMOVAPD(ZMM(4), ZMM(23)) + VMOVAPD(ZMM(4), ZMM(24)) + VMOVAPD(ZMM(4), ZMM(25)) + VMOVAPD(ZMM(4), ZMM(26)) + VMOVAPD(ZMM(4), ZMM(27)) + VMOVAPD(ZMM(4), ZMM(28)) + VMOVAPD(ZMM(4), ZMM(29)) + VMOVAPD(ZMM(4), ZMM(30)) + VMOVAPD(ZMM(4), ZMM(31)) + + MOV(VAR(k), RSI) + + MOV(VAR(a10), RAX) // load address of a + MOV(VAR(b01), RBX) // load address of b + MOV(VAR(c11), R8) // load address of c + + LEA(MEM(RSI,RSI,2), RDX) + LEA(MEM(,RDX,4), RDX) + LEA(MEM(RDX,RSI,4), RDX) // 16 * K + LEA(MEM(RAX,RDX,8,-128), RDX) // a_next + LEA(MEM(R8,63), R12) // c for prefetching + + MOV(IMM(14), RDI) + LEA(MEM(, RDI, 8), RDI) + + MOV(VAR(rs_c), R9) + MOV(VAR(cs_c), R13) + + MOV(IMM(0), R11) + MOV(R13, R15) + + CMP(IMM(8), R13) + JNE(.DBEFORELOOP) + MOV(IMM(2), R11) + MOV(R9, R15) + + LABEL(.DBEFORELOOP) + + VMOVAPD(MEM(RAX, 0*8*8), ZMM(0)) + VMOVAPD(MEM(RAX, 1*8*8), ZMM(1)) // preload a + + MOV(RSI, R10) + AND(IMM(3), R10) // R10 = K % 4 + SAR(IMM(2), RSI) // RSI = K / 4 + + /* + MAIN LOOP + Note: This loop runs (K/4 - 14 - TAIL_NITER) times + */ + SUB(R11, RSI) + SUB(IMM(14+TAIL_NITER), RSI) + JLE(K_LE_80) + + LOOP_ALIGN + LABEL(LOOP1) + SUBITER(0) + PREFETCH(1, MEM(RDX)) + SUBITER(1) + SUB(IMM(1), RSI) + SUBITER(2) + PREFETCH(1, MEM(RDX,64)) + SUBITER(3) + + LEA(MEM(RAX,4*16*8), RAX) + LEA(MEM(RBX,4*14*8), RBX) + LEA(MEM(RDX,16*8), RDX) + + JNZ(LOOP1) + + LABEL(K_LE_80) + + /* + C prefetch Loop + Note: This loop runs 14 times, + These 14 iterations are done seperately so that c11 can be prefetched here. + */ + ADD(R11, RSI) + ADD(IMM(14), RSI) + JLE(K_LE_24) + + LOOP_ALIGN + LABEL(LOOP2) + PREFETCH(0, MEM(R12)) + SUBITER(0) + PREFETCH(1, MEM(RDX)) + SUBITER(1) + PREFETCH(0, MEM(R12,64)) + SUB(IMM(1), RSI) + SUBITER(2) + PREFETCH(1, MEM(RDX,64)) + SUBITER(3) + + LEA(MEM(RAX,4*16*8), RAX) + LEA(MEM(RBX,4*14*8), RBX) + LEA(MEM(RDX,16*8), RDX) + LEA(MEM(R12,R15,1), R12) + + JNZ(LOOP2) + + LABEL(K_LE_24) + + /* + TAIL_NITER Loop + Note: This loop runs TAIL_NITER times, + This loop is used to provide some distance between c11 prefetch and usage of c11. + */ + ADD(IMM(0+TAIL_NITER), RSI) + JLE(TAIL) + + LOOP_ALIGN + LABEL(LOOP3) + + SUBITER(0) + PREFETCH(1, MEM(RDX)) + SUBITER(1) + SUB(IMM(1), RSI) + SUBITER(2) + PREFETCH(1, MEM(RDX,64)) + SUBITER(3) + + LEA(MEM(RAX,4*16*8), RAX) + LEA(MEM(RBX,4*14*8), RBX) + LEA(MEM(RDX,16*8), RDX) + + JNZ(LOOP3) + + /* + K Left Loop + This loop runs K % 4 times. 
+ */ + LABEL(TAIL) + MOV(R10, RSI) + TEST(RSI, RSI) + JE(.DPOSTACCUM) + LOOP_ALIGN + LABEL(TAIL_LOOP) + + SUB(IMM(1), RSI) + SUBITER(0) + + LEA(MEM(RAX,16*8), RAX) + LEA(MEM(RBX,14*8), RBX) + + JNZ(TAIL_LOOP) + + LABEL(.DPOSTACCUM) + /* GEMM output before transpose GEMM output after transpose + __________________________________ + ___________________________ |______zmm4______|______zmm20___x x| + | | | | | | | | | | | | | | | |______zmm6______|______zmm22___x x| + |z|z|z|z|z|z|z|z|z|z|z|z|z|z| |______zmm8______|______zmm24___x x| + |m|m|m|m|m|m|m|m|m|m|m|m|m|m| |______zmm10_____|______zmm26___x x| + |m|m|m|m|m|m|m|m|m|m|m|m|m|m| |______zmm12_____|______zmm28___x x| + |4|6|8|1|1|1|1|1|2|2|2|2|2|3| |______zmm14_____|______zmm30___x x| + | | | |0|2|4|6|8|0|2|4|6|8|0| |______zmm16_____|_____c11______x x| + | | | | | | | | | | | | | | | |______zmm18_____|_____c11+cs___x x| + ____________________________ |______zmm5______|______zmm21___x x| + | | | | | | | | | | | | | | | |______zmm7______|______zmm23___x x| + |z|z|z|z|z|z|z|z|z|z|z|z|z|z| |______zmm9______|______zmm25___x x| + |m|m|m|m|m|m|m|m|m|m|m|m|m|m| |______zmm11_____|______zmm27___x x| + |m|m|m|m|m|m|m|m|m|m|m|m|m|m| |______zmm13_____|______zmm29___x x| + |5|7|9|1|1|1|1|1|2|2|2|2|2|3| |______zmm15_____|______zmm31___x x| + | | | |1|3|5|7|9|1|3|5|7|9|1| |______zmm17_____|____c11+cs*2__x x| + | | | | | | | | | | | | | | | |______zmm19_____|____c11+cs*4__x x| + _____________________________ + */ + TRANSPOSE_REGISTERS_8x8(4, 6, 8, 10, 12, 14, 16, 18) // transpose the output of GEMM + TRANSPOSE_REGISTERS_8x8(5, 7, 9, 11, 13, 15, 17, 19) + TRANSPOSE_REGISTERS_6x8(20, 22, 24, 26, 28, 30, 0, 1) + VMOVUPD(ZMM(0), MEM(R8 )) + VMOVUPD(ZMM(1), MEM(R8, R12, 1)) // zmm0 and zmm1 are needed for other computations, + // therefore store zmm0, zmm1 's data in rcx + TRANSPOSE_REGISTERS_6x8(21, 23, 25, 27, 29, 31, 0, 1) + VMOVUPD(ZMM(0), MEM(R8, R12, 2)) + VMOVUPD(ZMM(1), MEM(R8, R12, 4)) // zmm0 and zmm1 are needed for other computations, + // therefore store zmm0, zmm1 's data in rcx + MOV(IMM(14), RDI) + LEA(MEM(, RDI, 8), RDI) + + MOV(VAR(alpha), RBX) + VBROADCASTSD(MEM(RBX), ZMM(3)) + + MOV(IMM(1), RSI) + LEA(MEM(, RSI, 8), RSI) + + MOV(VAR(b11), RCX) + LEA(MEM(RCX, RSI, 8), RDX) + + MOV(RCX, R11) + MOV(RDX, R14) + + // Scale by Alpha + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(4)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(6)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(8)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(10)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(12)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(14)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(16)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(18)) + ADD(RDI, RCX) + + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(5)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(7)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(9)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(11)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(13)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(15)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(17)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(19)) + + + + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(20)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(22)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(24)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(26)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(28)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(30)) + ADD(RDI, RDX) + 
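    // The VFMSUB231PD sequence above implements the GEMM part of gemmtrsm: each
    // accumulator is replaced by alpha*b11 - (a10*b01), using the products built up
    // in the loops above. The loads that follow bring back the two transposed rows
    // that were spilled to the c11 buffer (R8) so they receive the same alpha update.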
VMOVUPD(MEM(R8 ), ZMM(0)) + VMOVUPD(MEM(R8, R12, 1), ZMM(1)) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(0)) + VMOVUPD(ZMM(0), MEM(R8 )) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(1)) + VMOVUPD(ZMM(1), MEM(R8, R12, 1)) + ADD(RDI, RDX) + + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(21)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(23)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(25)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(27)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(29)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(31)) + ADD(RDI, RDX) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(0)) + VMOVUPD(ZMM(0), MEM(R8, R12, 2)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(1)) + VMOVUPD(ZMM(1), MEM(R8, R12, 4)) + + /* + TRSM region + Each row requires 1 iteration, therefore 16 iterations are present + */ + MOV(VAR(a11), RAX) + MOV(R11, RCX) + MOV(R14, RDX) + + + //iteration 0 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (0+0*16)*8), ZMM(0)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(4) , ZMM(4) ) + VMULPD(ZMM(0), ZMM(20), ZMM(20)) + #else + VDIVPD(ZMM(0), ZMM(4) , ZMM(4) ) + VDIVPD(ZMM(0), ZMM(20), ZMM(20)) + #endif + + VMOVUPD(ZMM(4), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(20), YMM(0)) + VMOVUPD(YMM(20), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) // move only first six values to rcx + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 1 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (1+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (1+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(6) , ZMM(6) ) + VSUBPD(ZMM(3), ZMM(22), ZMM(22)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(6) , ZMM(6) ) + VMULPD(ZMM(1), ZMM(22), ZMM(22)) + #else + VDIVPD(ZMM(1), ZMM(6) , ZMM(6)) + VDIVPD(ZMM(1), ZMM(22), ZMM(22)) + #endif + + VMOVUPD(ZMM(6), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(22), YMM(0)) + VMOVUPD(YMM(22), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 2 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (2+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (2+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + + VBROADCASTSD(MEM(RAX, (2+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(8) , ZMM(8) ) + VSUBPD(ZMM(3), ZMM(24), ZMM(24)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(8) , ZMM(8) ) + VMULPD(ZMM(0), ZMM(24), ZMM(24)) + #else + VDIVPD(ZMM(0), ZMM(8) , ZMM(8) ) + VDIVPD(ZMM(0), ZMM(24), ZMM(24)) + #endif + + VMOVUPD(ZMM(8), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(24), YMM(0)) + VMOVUPD(YMM(24), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 3 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (3+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (3+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (3+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (3+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(10), ZMM(10)) + VSUBPD(ZMM(3), ZMM(26), ZMM(26)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(10), ZMM(10)) + 
VMULPD(ZMM(1), ZMM(26), ZMM(26)) + #else + VDIVPD(ZMM(1), ZMM(10), ZMM(10)) + VDIVPD(ZMM(1), ZMM(26), ZMM(26)) + #endif + + VMOVUPD(ZMM(10), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(26), YMM(0)) + VMOVUPD(YMM(26), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 4 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (4+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (4+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (4+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (4+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (4+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(12), ZMM(12)) + VSUBPD(ZMM(3), ZMM(28), ZMM(28)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(12), ZMM(12)) + VMULPD(ZMM(0), ZMM(28), ZMM(28)) + #else + VDIVPD(ZMM(0), ZMM(12), ZMM(12)) + VDIVPD(ZMM(0), ZMM(28), ZMM(28)) + #endif + + VMOVUPD(ZMM(12), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(28), YMM(0)) + VMOVUPD(YMM(28), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 5 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (5+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (5+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(14), ZMM(14)) + VSUBPD(ZMM(3), ZMM(30), ZMM(30)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(14), ZMM(14)) + VMULPD(ZMM(1), ZMM(30), ZMM(30)) + #else + VDIVPD(ZMM(1), ZMM(14), ZMM(14)) + VDIVPD(ZMM(1), ZMM(30), ZMM(30)) + #endif + + VMOVUPD(ZMM(14), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(30), YMM(0)) + VMOVUPD(YMM(30), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 6 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (6+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (6+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+6*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(30), ZMM(3)) + + VMOVUPD(MEM(R8), ZMM(1)) + VSUBPD(ZMM(2), ZMM(16), ZMM(16)) + VSUBPD(ZMM(3), ZMM(1) , ZMM(1) ) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(16), ZMM(16)) + VMULPD(ZMM(0), ZMM(1) , ZMM(1) ) + #else + 
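    // Without TRSM pre-inversion the broadcast diagonal element of a11 is divided
    // out directly; when BLIS_ENABLE_TRSM_PREINVERSION is defined the packed
    // diagonal is expected to hold its reciprocal, so the branch above can use the
    // cheaper multiply instead.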
VDIVPD(ZMM(0), ZMM(16), ZMM(16)) + VDIVPD(ZMM(0), ZMM(1) , ZMM(1) ) + #endif + + VMOVUPD(ZMM(1), MEM(R8 )) + + VMOVUPD(ZMM(16), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(1), YMM(0)) + VMOVUPD(YMM(1), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 7 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (7+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (7+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+6*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(30), ZMM(3)) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (7+7*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(0)) + VSUBPD(ZMM(2), ZMM(18), ZMM(18)) + VSUBPD(ZMM(3), ZMM(0) , ZMM(0) ) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(18), ZMM(18)) + VMULPD(ZMM(1), ZMM(0) , ZMM(0) ) + #else + VDIVPD(ZMM(1), ZMM(18), ZMM(18)) + VDIVPD(ZMM(1), ZMM(0) , ZMM(0) ) + #endif + + VMOVUPD(ZMM(0), MEM(R8, R12, 1)) + + VMOVUPD(ZMM(18), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(0), YMM(1)) + VMOVUPD(YMM(0), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 8 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (8+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (8+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+6*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(30), ZMM(3)) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (8+7*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (8+8*16)*8), ZMM(0)) + + VSUBPD(ZMM(2), ZMM(5) , ZMM(5) ) + VSUBPD(ZMM(3), ZMM(21), ZMM(21)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(5), ZMM(5)) + VMULPD(ZMM(0), ZMM(21), ZMM(21)) + #else + VDIVPD(ZMM(0), ZMM(5), ZMM(5)) + VDIVPD(ZMM(0), ZMM(21), ZMM(21)) + #endif + + VMOVUPD(ZMM(5), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(21), YMM(1)) + VMOVUPD(YMM(21), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 9 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (9+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (9+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , 
ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (9+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (9+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (9+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (9+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (9+6*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(30), ZMM(3)) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (9+7*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (9+8*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (9+9*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(5), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(21), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(7), ZMM(7)) + VSUBPD(ZMM(3), ZMM(23), ZMM(23)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(7), ZMM(7)) + VMULPD(ZMM(1), ZMM(23), ZMM(23)) + #else + VDIVPD(ZMM(1), ZMM(7), ZMM(7)) + VDIVPD(ZMM(1), ZMM(23), ZMM(23)) + #endif + + VMOVUPD(ZMM(7), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(23), YMM(1)) + VMOVUPD(YMM(23), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 10 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (10+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (10+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (10+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (10+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (10+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (10+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (10+6*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(30), ZMM(3)) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (10+7*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (10+8*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (10+9*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(5), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(21), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (10+10*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(7), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(23), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(9), ZMM(9)) + VSUBPD(ZMM(3), ZMM(25), ZMM(25)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(9), ZMM(9)) + VMULPD(ZMM(0), ZMM(25), ZMM(25)) + #else + VDIVPD(ZMM(0), ZMM(9), ZMM(9)) + VDIVPD(ZMM(0), ZMM(25), ZMM(25)) + #endif + + VMOVUPD(ZMM(9), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(25), YMM(1)) + VMOVUPD(YMM(25), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 11 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (11+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (11+1*16)*8), ZMM(1)) + + 
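+    // The TRSM iterations all follow the same pattern: broadcast the entries
+    // of the current row of a11, fold the already-solved rows into zmm2/zmm3
+    // with FMAs, subtract that update from the current right-hand-side row,
+    // scale by the diagonal entry (a multiply under
+    // BLIS_ENABLE_TRSM_PREINVERSION, a divide otherwise), and store the
+    // solved row back to b11 (eight columns, then the remaining six).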
VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (11+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (11+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (11+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (11+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (11+6*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(30), ZMM(3)) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (11+7*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (11+8*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (11+9*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(5), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(21), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (11+10*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(7), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (11+11*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(9), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(25), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(11), ZMM(11)) + VSUBPD(ZMM(3), ZMM(27), ZMM(27)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(11), ZMM(11)) + VMULPD(ZMM(1), ZMM(27), ZMM(27)) + #else + VDIVPD(ZMM(1), ZMM(11), ZMM(11)) + VDIVPD(ZMM(1), ZMM(27), ZMM(27)) + #endif + + VMOVUPD(ZMM(11), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(27), YMM(1)) + VMOVUPD(YMM(27), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 12 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (12+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (12+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (12+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (12+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (12+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (12+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (12+6*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(30), ZMM(3)) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (12+7*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (12+8*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (12+9*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(5), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(21), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (12+10*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(7), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (12+11*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(9), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (12+12*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(27), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(13), ZMM(13)) + VSUBPD(ZMM(3), 
ZMM(29), ZMM(29)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(13), ZMM(13)) + VMULPD(ZMM(0), ZMM(29), ZMM(29)) + #else + VDIVPD(ZMM(0), ZMM(13), ZMM(13)) + VDIVPD(ZMM(0), ZMM(29), ZMM(29)) + #endif + + VMOVUPD(ZMM(13), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(29), YMM(1)) + VMOVUPD(YMM(29), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 13 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (13+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (13+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (13+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (13+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (13+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (13+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (13+6*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(30), ZMM(3)) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (13+7*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (13+8*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (13+9*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(5), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(21), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (13+10*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(7), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (13+11*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(9), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (13+12*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (13+13*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(29), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(15), ZMM(15)) + VSUBPD(ZMM(3), ZMM(31), ZMM(31)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(15), ZMM(15)) + VMULPD(ZMM(1), ZMM(31), ZMM(31)) + #else + VDIVPD(ZMM(1), ZMM(15), ZMM(15)) + VDIVPD(ZMM(1), ZMM(31), ZMM(31)) + #endif + + VMOVUPD(ZMM(15), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(31), YMM(1)) + VMOVUPD(YMM(31), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 14 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (14+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (14+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+6*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(30), ZMM(3)) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(16), ZMM(2)) + 
VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (14+7*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (14+8*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (14+9*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(5), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(21), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+10*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(7), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+11*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(9), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+12*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+13*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+14*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(31), ZMM(3)) + + VMOVUPD(MEM(R8, R12, 2), ZMM(1)) + VSUBPD(ZMM(2), ZMM(17), ZMM(17)) + VSUBPD(ZMM(3), ZMM(1) , ZMM(1) ) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(17), ZMM(17)) + VMULPD(ZMM(0), ZMM(1) , ZMM(1) ) + #else + VDIVPD(ZMM(0), ZMM(17), ZMM(17)) + VDIVPD(ZMM(0), ZMM(1) , ZMM(1) ) + #endif + + VMOVUPD(ZMM(1), MEM(R8, R12, 2)) + + VMOVUPD(ZMM(17), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(1), YMM(0)) + VMOVUPD(YMM(1), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 15 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (15+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (15+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (15+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (15+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (15+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (15+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (15+6*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(30), ZMM(3)) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (15+7*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (15+8*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (15+9*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(5), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(21), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (15+10*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(7), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (15+11*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(9), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (15+12*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (15+13*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (15+14*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(31), ZMM(3)) + + VMOVUPD(MEM(R8, R12, 2), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1 ), ZMM(3)) + 
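+    // Rows whose accumulators were spilled to the c11 scratch area after the
+    // 6x8 transposes are reloaded from MEM(R8, ...) and folded in exactly
+    // like the rows that stayed resident in zmm registers.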
VBROADCASTSD(MEM(RAX, (15+15*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8, R12, 4), ZMM(0)) + VSUBPD(ZMM(2), ZMM(19), ZMM(19)) + VSUBPD(ZMM(3), ZMM(0) , ZMM(0) ) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(19), ZMM(19)) + VMULPD(ZMM(1), ZMM(0) , ZMM(0) ) + #else + VDIVPD(ZMM(1), ZMM(19), ZMM(19)) + VDIVPD(ZMM(1), ZMM(0) , ZMM(0) ) + #endif + + VMOVUPD(ZMM(0), MEM(R8, R12, 4)) + + VMOVUPD(ZMM(19), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(0), YMM(1)) + VMOVUPD(YMM(0), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + + /* + Storage Region (Post TRSM) + */ + MOV(R8, RCX) + MOV(R9, RDI) + MOV(VAR(cs_c), RSI) + + LEA(MEM(RCX, RSI, 8), RDX) + LEA(MEM(RCX, RDI, 8), R14) + + LEA(MEM(RSI, RSI, 2), R12) + LEA(MEM(RSI, RSI, 4), R13) + LEA(MEM(R13, RSI, 2), R15) + + CMP(IMM(8), RSI) + JZ(.DROWSTORED) + + CMP(IMM(8), RDI) + JZ(.DCOLSTORED) + + LABEL(.DROWSTORED) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VMOVUPD(ZMM(4), MEM(RCX)) + VMOVUPD(MEM(R8, RDI, 1), ZMM(4)) + + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(20), YMM(0)) + VMOVUPD(YMM(20), MEM(RDX )) + + VMOVUPD(MEM(R8, RDI, 2), ZMM(20)) + + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(6), MEM(RCX)) + + VMOVUPD(MEM(R8, RDI, 4), ZMM(6)) + + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(22), YMM(0)) + VMOVUPD(YMM(22), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(8), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(24), YMM(0)) + VMOVUPD(YMM(24), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(10), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(26), YMM(0)) + VMOVUPD(YMM(26), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(12), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(28), YMM(0)) + VMOVUPD(YMM(28), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(14), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(30), YMM(0)) + VMOVUPD(YMM(30), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(16), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(1), YMM(0)) + VMOVUPD(YMM(1), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(18), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(4), YMM(0)) + VMOVUPD(YMM(4), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(5), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(21), YMM(0)) + VMOVUPD(YMM(21), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(7), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(23), YMM(0)) + VMOVUPD(YMM(23), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(9), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(25), YMM(0)) + VMOVUPD(YMM(25), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(11), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(27), YMM(0)) + VMOVUPD(YMM(27), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(13), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(29), YMM(0)) + VMOVUPD(YMM(29), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(15), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(31), YMM(0)) + VMOVUPD(YMM(31), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(17), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(20), YMM(0)) + VMOVUPD(YMM(20), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(19), 
MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(6), YMM(0)) + VMOVUPD(YMM(6), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + + + JMP(.DDONE) + LABEL(.DCOLSTORED) + + + MOV(VAR(offsetPtr), R12) + LEA(MEM(RCX, RSI, 8), RDX) + VPBROADCASTQ(RSI, ZMM(0)) + VPMULLQ(MEM(R12), ZMM(0), ZMM(2)) + VPMULLQ(MEM(R12,64), ZMM(0), ZMM(3)) + + VMOVUPD(MEM(RCX ), ZMM(0)) + VMOVUPD(MEM(RCX, RSI, 1), ZMM(1)) + + MOV(RDX, RCX) + LEA(MEM(RCX, RDI, 8), R14) + UPDATE_C_COL_SCATTERED_2x6(20,21) + VMOVUPD(MEM(R8, RSI, 2), ZMM(20)) + VMOVUPD(MEM(R8, RSI, 4), ZMM(21)) + UPDATE_C_COL_SCATTERED_2x6(22,23) + UPDATE_C_COL_SCATTERED_2x6(24,25) + UPDATE_C_COL_SCATTERED_2x6(26,27) + UPDATE_C_COL_SCATTERED_2x6(28,29) + UPDATE_C_COL_SCATTERED_2x6(30,31) + UPDATE_C_COL_SCATTERED_2x6(0 ,20 ) + UPDATE_C_COL_SCATTERED_2x6(1 ,21 ) + + MOV(R8, RCX) + LEA(MEM(RCX, RDI, 8), R14) + UPDATE_C_COL_SCATTERED( 4, 5) + UPDATE_C_COL_SCATTERED( 6, 7) + UPDATE_C_COL_SCATTERED( 8, 9) + UPDATE_C_COL_SCATTERED(10,11) + UPDATE_C_COL_SCATTERED(12,13) + UPDATE_C_COL_SCATTERED(14,15) + UPDATE_C_COL_SCATTERED(16,17) + UPDATE_C_COL_SCATTERED(18,19) + + + LABEL(.DDONE) + + VZEROUPPER() + + end_asm( + : // output operands (none) + : // input operands + [a10] "m" (a10), // 1 + [k] "m" (k), // 2 + [b01] "m" (b01), // 3 + [a11] "m" (a11), // 6 + [b11] "m" (b11), // 7 + [c11] "m" (c11), // 8 + [rs_c] "m" (rs_c), // 9 + [cs_c] "m" (cs_c), // 10, + [alpha] "m" (alpha), + [offsetPtr] "m" (offsetPtr) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12", + "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", + "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", + "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", + "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", + "zmm30", "zmm31", "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_9); +} \ No newline at end of file diff --git a/kernels/zen4/3/bli_gemmtrsm_u_zen_16x14.c b/kernels/zen4/3/bli_gemmtrsm_u_zen_16x14.c new file mode 100644 index 0000000000..787f85155c --- /dev/null +++ b/kernels/zen4/3/bli_gemmtrsm_u_zen_16x14.c @@ -0,0 +1,1706 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +#define A_L1_PREFETCH_DIST 12 // in units of k iterations +#define B_L1_PREFETCH_DIST 12 // e.g. 4 k iterations ~= 56 cycles +#define TAIL_NITER 5 // in units of 4x unrolled k iterations + // e.g. 5 -> 4*5 k iterations ~= 280 cycles + +#define PREFETCH_A_L1(n, k) \ + PREFETCH(0, MEM(RAX, A_L1_PREFETCH_DIST*16*8 + (2*n+k)*64)) +#define PREFETCH_B_L1(n, k) \ + PREFETCH(0, MEM(RBX, B_L1_PREFETCH_DIST*14*8 + (2*n+k)*56)) + +#define LOOP_ALIGN ALIGN32 + +#define SUBITER(n) \ +\ + PREFETCH_A_L1(n, 0) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 0)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 1)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(4)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(5)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(6)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(7)) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 2)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 3)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(8) ) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(9) ) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(10)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(11)) \ + \ + PREFETCH_B_L1(n, 0) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 4)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 5)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(12)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(13)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(14)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(15)) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 6)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 7)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(16)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(17)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(18)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(19)) \ + \ + PREFETCH_A_L1(n, 1) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 8)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 9)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(20)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(21)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(22)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(23)) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 10)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 11)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(24)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(25)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(26)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(27)) \ + \ + PREFETCH_B_L1(n, 1) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 12)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 13)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(28)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(29)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(30)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(31)) \ + \ + VMOVAPD(MEM(RAX,((n*2)+2)*8*8), ZMM(0)) \ + VMOVAPD(MEM(RAX,((n*2)+3)*8*8), ZMM(1)) + +#define UPDATE_C_COL_SCATTERED(R1,R2) \ +\ + KXNORW(K(0), K(0), K(1)) \ + KXNORW(K(0), K(0), K(2)) \ + VSCATTERQPD(ZMM(R1), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + VSCATTERQPD(ZMM(R2), MEM(R14,ZMM(2),1) MASK_K(2)) \ + ADD(RDI, RCX) \ + ADD(RDI, R14) \ + +/* scatter only first 6 elements of r1 and r2 */ +#define UPDATE_C_COL_SCATTERED_2x6(R1,R2) \ +\ + KXNORW(K(0), K(0), K(1)) \ + KXNORW(K(0), 
K(0), K(2)) \ + MOVQ(IMM(0b00111111), RAX) \ + KMOVQ(RAX, K(2)) \ + KMOVQ(RAX, K(1)) \ + VSCATTERQPD(ZMM(R1), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + VSCATTERQPD(ZMM(R2), MEM(R14,ZMM(2),1) MASK_K(2)) \ + ADD(RDI, RCX) \ + ADD(RDI, R14) \ + +/* +Transpose 8 zmm registers and store the output in the given 8 registers + Note: Requires offsetPointer for scatter instruction + and 512 bytes of free memory (rcx) for transpose. + Input : + R1 = [ 0, 1, 2, 3, 4, 5, 6, 7] + R2 = [ 8, 9, 10, 11, 12, 13, 14, 15] + R3 = [16, 17, 18, 19, 20, 21, 22, 23] + R4 = [24, 25, 26, 27, 28, 29, 30, 31] + R5 = [32, 33, 34, 35, 36, 37, 38, 39] + R6 = [40, 41, 42, 43, 44, 45, 46, 47] + R7 = [48, 49, 50, 51, 52, 53, 54, 55] + R18= [56, 57, 58, 59, 60, 61, 62, 63] + Output : + R1 = [0, 8, 16, 24, 32, 40, 48, 56] + R2 = [1, 9, 17, 25, 33, 41, 49, 57] + R3 = [2, 10, 18, 26, 34, 42, 50, 58] + R4 = [3, 11, 19, 27, 35, 43, 51, 59] + R5 = [4, 12, 20, 28, 36, 44, 52, 60] + R6 = [5, 13, 21, 29, 37, 45, 53, 61] + R7 = [6, 14, 22, 30, 38, 46, 54, 62] + R18= [7, 15, 23, 31, 39, 47, 55, 63] +*/ +#define TRANSPOSE_REGISTERS_8x8(R1, R2, R3, R4, R5, R6, R7, R18) \ +\ + MOV(R8, RCX) \ + MOV(VAR(cs_c), RSI) \ + MOV(R9, RDI) \ + LEA(MEM(RCX, RSI, 8), RDX) \ + MOV(VAR(offsetPtr), R13) \ + MOV(RDI, R12) \ + CMP(RSI, R12) \ + CMOVL(RSI, R12) \ + VPBROADCASTQ(R12, ZMM(0)) \ + VPMULLQ(MEM(R13 ), ZMM(0), ZMM(2)) \ + \ + KXNORW(K(0), K(0), K(1)) \ + KXNORW(K(0), K(0), K(2)) \ + KXNORW(K(0), K(0), K(3)) \ + KXNORW(K(0), K(0), K(4)) \ + VSCATTERQPD(ZMM(R1), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(1)) \ + VSCATTERQPD(ZMM(R2), MEM(RCX,ZMM(2),1) MASK_K(2)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(2)) \ + VSCATTERQPD(ZMM(R3), MEM(RCX,ZMM(2),1) MASK_K(3)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(3)) \ + VSCATTERQPD(ZMM(R4), MEM(RCX,ZMM(2),1) MASK_K(4)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(4)) \ + VSCATTERQPD(ZMM(R5), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(1)) \ + VSCATTERQPD(ZMM(R6), MEM(RCX,ZMM(2),1) MASK_K(2)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(2)) \ + VSCATTERQPD(ZMM(R7), MEM(RCX,ZMM(2),1) MASK_K(3)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(3)) \ + VSCATTERQPD(ZMM(R18), MEM(RCX,ZMM(2),1) MASK_K(4)) \ + \ + MOV(R8, RCX) \ + \ + VMOVUPD(MEM(RCX), ZMM(R1))\ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R2)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R3)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R4)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R5)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R6)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R7)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R18)) \ + +/* +Transpose six zmm registers and store the output in the given 8 registers + Note: Require offsetPointer for scatter instruction + and 512 bytes of free memory (rcx) for transpose. 
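+    The transpose goes through memory rather than in-register shuffles: the
+    input rows are scattered into the scratch area (the stride appears to be
+    the larger of rs_c and cs_c), then the transposed rows are read back with
+    contiguous loads.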
+ Input : + R1 = [ 0, 1, 2, 3, 4, 5, 6, 7] + R2 = [ 8, 9, 10, 11, 12, 13, 14, 15] + R3 = [16, 17, 18, 19, 20, 21, 22, 23] + R4 = [24, 25, 26, 27, 28, 29, 30, 31] + R5 = [32, 33, 34, 35, 36, 37, 38, 39] + R6 = [40, 41, 42, 43, 44, 45, 46, 47] + Output : + R1 = [0, 8, 16, 24, 32, 40, -, -] + R2 = [1, 9, 17, 25, 33, 41, -, -] + R3 = [2, 10, 18, 26, 34, 42, -, -] + R4 = [3, 11, 19, 27, 35, 43, -, -] + R5 = [4, 12, 20, 28, 36, 44, -, -] + R6 = [5, 13, 21, 29, 37, 45, -, -] + R7 = [6, 14, 22, 30, 38, 46, -, -] + R18 = [7, 15, 23, 31, 39, 47, -, -] +*/ +#define TRANSPOSE_REGISTERS_6x8(R1, R2, R3, R4, R5, R6, R7, R18) \ +\ + MOV(R8, RCX) \ + MOV(VAR(cs_c), RSI) \ + MOV(R9, RDI) \ + LEA(MEM(RCX, RSI, 8), RDX) \ + MOV(VAR(offsetPtr), R13) \ + MOV(RDI, R12) \ + CMP(RSI, R12) \ + CMOVL(RSI, R12) \ + VPBROADCASTQ(R12, ZMM(0)) \ + VPMULLQ(MEM(R13 ), ZMM(0), ZMM(2)) \ + LEA(MEM(RCX, R12, 4), RCX) \ + LEA(MEM(RCX, R12, 1), RCX) \ + \ + KXNORW(K(0), K(0), K(1)) \ + KXNORW(K(0), K(0), K(2)) \ + KXNORW(K(0), K(0), K(3)) \ + KXNORW(K(0), K(0), K(4)) \ + VSCATTERQPD(ZMM(R1), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(1)) \ + VSCATTERQPD(ZMM(R2), MEM(RCX,ZMM(2),1) MASK_K(2)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(2)) \ + VSCATTERQPD(ZMM(R3), MEM(RCX,ZMM(2),1) MASK_K(3)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(3)) \ + VSCATTERQPD(ZMM(R4), MEM(RCX,ZMM(2),1) MASK_K(4)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(4)) \ + VSCATTERQPD(ZMM(R5), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(1)) \ + VSCATTERQPD(ZMM(R6), MEM(RCX,ZMM(2),1) MASK_K(2)) \ + \ + MOV(R8, RCX) \ + LEA(MEM(RCX, R12, 4), RCX) \ + LEA(MEM(RCX, R12, 1), RCX) \ + \ + VMOVUPD(MEM(RCX), ZMM(R1))\ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R2)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R3)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R4)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R5)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R6)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R7)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R18)) \ + +// Offsets for scatter/gather instructions +static int64_t offsets[16] __attribute__((aligned(64))) = + { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15}; + + +void bli_dgemmtrsm_u_zen_asm_16x14 +( + dim_t k_, + double* restrict alpha, + double* restrict a10, + double* restrict a11, + double* restrict b01, + double* restrict b11, + double* restrict c11, inc_t rs_c_, inc_t cs_c_, + auxinfo_t* restrict data, + cntx_t* restrict cntx +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_9); + const int64_t k = k_; + uint64_t rs_c = rs_c_ * 8; + const int64_t* offsetPtr = &offsets[0]; + uint64_t cs_c = cs_c_ * 8; + + BEGIN_ASM() + + //clear out registers + VXORPD(ZMM(4), ZMM(4), ZMM(4)) + VMOVAPD(ZMM(4), ZMM(5) ) + VMOVAPD(ZMM(4), ZMM(6) ) + VMOVAPD(ZMM(4), ZMM(7) ) + VMOVAPD(ZMM(4), ZMM(8) ) + VMOVAPD(ZMM(4), ZMM(9) ) + VMOVAPD(ZMM(4), ZMM(10)) + VMOVAPD(ZMM(4), ZMM(11)) + VMOVAPD(ZMM(4), ZMM(12)) + VMOVAPD(ZMM(4), ZMM(13)) + VMOVAPD(ZMM(4), ZMM(14)) + VMOVAPD(ZMM(4), ZMM(15)) + VMOVAPD(ZMM(4), ZMM(16)) + VMOVAPD(ZMM(4), ZMM(17)) + VMOVAPD(ZMM(4), ZMM(18)) + VMOVAPD(ZMM(4), ZMM(19)) + VMOVAPD(ZMM(4), ZMM(20)) + VMOVAPD(ZMM(4), ZMM(21)) + VMOVAPD(ZMM(4), ZMM(22)) + VMOVAPD(ZMM(4), ZMM(23)) + VMOVAPD(ZMM(4), ZMM(24)) + VMOVAPD(ZMM(4), ZMM(25)) + VMOVAPD(ZMM(4), ZMM(26)) + VMOVAPD(ZMM(4), ZMM(27)) + VMOVAPD(ZMM(4), ZMM(28)) + VMOVAPD(ZMM(4), ZMM(29)) + VMOVAPD(ZMM(4), ZMM(30)) + VMOVAPD(ZMM(4), ZMM(31)) + + MOV(VAR(k), RSI) + + MOV(VAR(a10), 
RAX) // load address of a + MOV(VAR(b01), RBX) // load address of b + MOV(VAR(c11), R8) // load address of c + + LEA(MEM(RSI,RSI,2), RDX) + LEA(MEM(,RDX,4), RDX) + LEA(MEM(RDX,RSI,4), RDX) // RDX = 16 * K + LEA(MEM(RAX,RDX,8,-128), RDX) // RDX = a_next for prefetching + LEA(MEM(R8,63), R12) // c for prefetching + + MOV(VAR(rs_c), R9) + MOV(VAR(cs_c), R13) + + MOV(IMM(0), R11) + MOV(R13, R15) + + CMP(IMM(8), R13) + JNE(.DBEFORELOOP) + MOV(IMM(2), R11) + MOV(R9, R15) + + LABEL(.DBEFORELOOP) + + VMOVAPD(MEM(RAX, 0*8*8), ZMM(0)) + VMOVAPD(MEM(RAX, 1*8*8), ZMM(1)) // preload a + + MOV(RSI, R10) + AND(IMM(3), R10) // R10 = K % 4 + SAR(IMM(2), RSI) // RSI = K / 4 + + /* + MAIN LOOP + Note: This loop runs (K/4 - 14 - TAIL_NITER) times + */ + SUB(R11, RSI) + SUB(IMM(14+TAIL_NITER), RSI) + JLE(K_LE_80) + + LOOP_ALIGN + LABEL(LOOP1) + SUBITER(0) + PREFETCH(1, MEM(RDX)) + SUBITER(1) + SUB(IMM(1), RSI) + SUBITER(2) + PREFETCH(1, MEM(RDX,64)) + SUBITER(3) + + LEA(MEM(RAX,4*16*8), RAX) + LEA(MEM(RBX,4*14*8), RBX) + LEA(MEM(RDX,16*8), RDX) + + JNZ(LOOP1) + + LABEL(K_LE_80) + + /* + C prefetch Loop + Note: This loop runs 14 times, + These 14 iterations are done seperately so that c11 can be prefetched here. + */ + ADD(R11, RSI) + ADD(IMM(14), RSI) + JLE(K_LE_24) + + LOOP_ALIGN + LABEL(LOOP2) + PREFETCH(0, MEM(R12)) + SUBITER(0) + PREFETCH(1, MEM(RDX)) + SUBITER(1) + PREFETCH(0, MEM(R12,64)) + SUB(IMM(1), RSI) + SUBITER(2) + PREFETCH(1, MEM(RDX,64)) + SUBITER(3) + + LEA(MEM(RAX,4*16*8), RAX) + LEA(MEM(RBX,4*14*8), RBX) + LEA(MEM(RDX,16*8), RDX) + LEA(MEM(R12,R15,1), R12) + + JNZ(LOOP2) + + LABEL(K_LE_24) + + /* + TAIL_NITER Loop + Note: This loop runs TAIL_NITER times, + This loop is used to provide some distance between c11 prefetch and usage of c11. + */ + ADD(IMM(0+TAIL_NITER), RSI) + JLE(TAIL) + + LOOP_ALIGN + LABEL(LOOP3) + + SUBITER(0) + PREFETCH(1, MEM(RDX)) + SUBITER(1) + SUB(IMM(1), RSI) + SUBITER(2) + PREFETCH(1, MEM(RDX,64)) + SUBITER(3) + + LEA(MEM(RAX,4*16*8), RAX) + LEA(MEM(RBX,4*14*8), RBX) + LEA(MEM(RDX,16*8), RDX) + + JNZ(LOOP3) + + /* + K Left Loop + This loop runs K % 4 times. 
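+        Each pass performs one SUBITER(0) and advances the packed A panel by
+        16 doubles and the packed B panel by 14 doubles.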
+ */ + LABEL(TAIL) + MOV(R10, RSI) + TEST(RSI, RSI) + JE(.DPOSTACCUM) + LOOP_ALIGN + LABEL(TAIL_LOOP) + + SUB(IMM(1), RSI) + SUBITER(0) + + LEA(MEM(RAX,16*8), RAX) + LEA(MEM(RBX,14*8), RBX) + + JNZ(TAIL_LOOP) + + LABEL(.DPOSTACCUM) + + /* GEMM output before transpose GEMM output after transpose + __________________________________ + ___________________________ |______zmm4______|______zmm20___x x| + | | | | | | | | | | | | | | | |______zmm6______|______zmm22___x x| + |z|z|z|z|z|z|z|z|z|z|z|z|z|z| |______zmm8______|______zmm24___x x| + |m|m|m|m|m|m|m|m|m|m|m|m|m|m| |______zmm10_____|______zmm26___x x| + |m|m|m|m|m|m|m|m|m|m|m|m|m|m| |______zmm12_____|______zmm28___x x| + |4|6|8|1|1|1|1|1|2|2|2|2|2|3| |______zmm14_____|______zmm30___x x| + | | | |0|2|4|6|8|0|2|4|6|8|0| |______zmm16_____|_____c11______x x| + | | | | | | | | | | | | | | | |______zmm18_____|_____c11+cs___x x| + ____________________________ |______zmm5______|______zmm21___x x| + | | | | | | | | | | | | | | | |______zmm7______|______zmm23___x x| + |z|z|z|z|z|z|z|z|z|z|z|z|z|z| |______zmm9______|______zmm25___x x| + |m|m|m|m|m|m|m|m|m|m|m|m|m|m| |______zmm11_____|______zmm27___x x| + |m|m|m|m|m|m|m|m|m|m|m|m|m|m| |______zmm13_____|______zmm29___x x| + |5|7|9|1|1|1|1|1|2|2|2|2|2|3| |______zmm15_____|______zmm31___x x| + | | | |1|3|5|7|9|1|3|5|7|9|1| |______zmm17_____|____c11+cs*2__x x| + | | | | | | | | | | | | | | | |______zmm19_____|____c11+cs*4__x x| + _____________________________ + */ + TRANSPOSE_REGISTERS_8x8(4, 6, 8, 10, 12, 14, 16, 18) // transpose the output of GEMM + TRANSPOSE_REGISTERS_8x8(5, 7, 9, 11, 13, 15, 17, 19) + TRANSPOSE_REGISTERS_6x8(20, 22, 24, 26, 28, 30, 0, 1) + VMOVUPD(ZMM(0), MEM(R8 )) + VMOVUPD(ZMM(1), MEM(R8, R12, 1)) // zmm0 and zmm1 are needed for other computations, + // therefore store zmm0, zmm1 's data in rcx + TRANSPOSE_REGISTERS_6x8(21, 23, 25, 27, 29, 31, 0, 1) + VMOVUPD(ZMM(0), MEM(R8, R12, 2)) + VMOVUPD(ZMM(1), MEM(R8, R12, 4)) // zmm0 and zmm1 are needed for other computations, + // therefore store zmm0, zmm1 's data in rcx + MOV(IMM(14), RDI) + LEA(MEM(, RDI, 8), RDI) + + MOV(VAR(alpha), RBX) + VBROADCASTSD(MEM(RBX), ZMM(3)) + + MOV(IMM(1), RSI) + LEA(MEM(, RSI, 8), RSI) + + MOV(VAR(b11), RCX) + LEA(MEM(RCX, RSI, 8), RDX) + + MOV(RCX, R11) + MOV(RDX, R14) + + // Subtract b11 from GEMM output + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(4)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(6)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(8)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(10)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(12)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(14)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(16)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(18)) + ADD(RDI, RCX) + + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(5)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(7)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(9)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(11)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(13)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(15)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(17)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(19)) + + + + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(20)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(22)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(24)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(26)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(28)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(30)) 
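+    // Each VFMSUB231PD in this block forms alpha*b11 - (GEMM accumulator) for
+    // one row in a single instruction; the result overwrites the accumulator
+    // register (or its spilled copy kept in the c11 scratch area).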
+ ADD(RDI, RDX) + VMOVUPD(MEM(R8 ), ZMM(0)) + VMOVUPD(MEM(R8, R12, 1), ZMM(1)) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(0)) + VMOVUPD(ZMM(0), MEM(R8 )) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(1)) + VMOVUPD(ZMM(1), MEM(R8, R12, 1)) + ADD(RDI, RDX) + + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(21)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(23)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(25)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(27)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(29)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(31)) + ADD(RDI, RDX) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(0)) + VMOVUPD(ZMM(0), MEM(R8, R12, 2)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(1)) + VMOVUPD(ZMM(1), MEM(R8, R12, 4)) + + /* + TRSM region + Each row requires 1 iteration, therefore 16 iterations are present + */ + MOV(VAR(a11), RAX) + MOV(R11, RCX) + MOV(R14, RDX) + + LEA(MEM(RDI, RDI, 4), R14) + LEA(MEM(R14, R14, 2), R14) // R14 = RDI * 15 + LEA(MEM(RCX, R14, 1), RCX) // rcx = b11 + (16-1)*rs_b + LEA(MEM(RDX, R14, 1), RDX) // rdx = b11 + (16-1)*rs_b + 8*cs_b + + //iteration 0 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (15+15*16)*8), ZMM(0)) + + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(19), ZMM(19)) + VMULPD(ZMM(0), ZMM(1) , ZMM(1) ) + #else + VDIVPD(ZMM(0), ZMM(19),ZMM(19) ) + VDIVPD(ZMM(0), ZMM(1) , ZMM(1) ) + #endif + VMOVUPD(ZMM(1), MEM(R8, R12, 4)) + + VMOVUPD(ZMM(19), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(1), YMM(0)) + VMOVUPD(YMM(1), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) // move only first six values to rcx + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 1 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (14+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VSUBPD(ZMM(2), ZMM(17), ZMM(17)) + VSUBPD(ZMM(3), ZMM(0) , ZMM(0) ) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(17), ZMM(17)) + VMULPD(ZMM(1), ZMM(0) , ZMM(0) ) + #else + VDIVPD(ZMM(1), ZMM(17), ZMM(17)) + VDIVPD(ZMM(1), ZMM(0) , ZMM(0) ) + #endif + + VMOVUPD(ZMM(0), MEM(R8, R12, 2)) + + VMOVUPD(ZMM(17), MEM(RCX)) + + VEXTRACTF64X4(IMM(1), ZMM(0), YMM(1)) + VMOVUPD(YMM(0), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 2 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (13+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + + VBROADCASTSD(MEM(RAX, (13+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (13+13*16)*8), ZMM(0)) + VSUBPD(ZMM(2), ZMM(15), ZMM(15)) + VSUBPD(ZMM(3), ZMM(31), ZMM(31)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(15), ZMM(15)) + VMULPD(ZMM(0), ZMM(31), ZMM(31)) + #else + VDIVPD(ZMM(0), ZMM(15), ZMM(15)) + VDIVPD(ZMM(0), ZMM(31), ZMM(31)) + #endif + + VMOVUPD(ZMM(15), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(31), YMM(0)) + VMOVUPD(YMM(31), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 3 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (12+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + 
VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (12+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (12+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (12+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(13), ZMM(13)) + VSUBPD(ZMM(3), ZMM(29), ZMM(29)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(13), ZMM(13)) + VMULPD(ZMM(1), ZMM(29), ZMM(29)) + #else + VDIVPD(ZMM(1), ZMM(13), ZMM(13)) + VDIVPD(ZMM(1), ZMM(29), ZMM(29)) + #endif + + VMOVUPD(ZMM(13), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(29), YMM(0)) + VMOVUPD(YMM(29), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 4 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (11+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (11+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (11+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (11+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (11+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(11), ZMM(11)) + VSUBPD(ZMM(3), ZMM(27), ZMM(27)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(11), ZMM(11)) + VMULPD(ZMM(0), ZMM(27), ZMM(27)) + #else + VDIVPD(ZMM(0), ZMM(11), ZMM(11)) + VDIVPD(ZMM(0), ZMM(27), ZMM(27)) + #endif + + VMOVUPD(ZMM(11), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(27), YMM(0)) + VMOVUPD(YMM(27), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 5 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (10+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (10+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (10+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (10+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (10+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (10+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(9) , ZMM(9) ) + VSUBPD(ZMM(3), ZMM(25), ZMM(25)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(9) , ZMM(9) ) + VMULPD(ZMM(1), ZMM(25), ZMM(25)) + #else + VDIVPD(ZMM(1), ZMM(9) , ZMM(9) ) + VDIVPD(ZMM(1), ZMM(25), ZMM(25)) + #endif + + VMOVUPD(ZMM(9) , MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(25), YMM(0)) + VMOVUPD(YMM(25), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 6 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (9+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (9+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + 
VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (9+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (9+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (9+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (9+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (9+9*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(9) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(25), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(7) , ZMM(7) ) + VSUBPD(ZMM(3), ZMM(23), ZMM(23)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(7) , ZMM(7) ) + VMULPD(ZMM(0), ZMM(23), ZMM(23)) + #else + VDIVPD(ZMM(0), ZMM(7) , ZMM(7) ) + VDIVPD(ZMM(0), ZMM(23), ZMM(23)) + #endif + + VMOVUPD(ZMM(7), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(23), YMM(0)) + VMOVUPD(YMM(23), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 7 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (8+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (8+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+9*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(9) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+8*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(7) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(23), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(5) , ZMM(5) ) + VSUBPD(ZMM(3), ZMM(21), ZMM(21)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(5) , ZMM(5) ) + VMULPD(ZMM(1), ZMM(21), ZMM(21)) + #else + VDIVPD(ZMM(1), ZMM(5) , ZMM(5) ) + VDIVPD(ZMM(1), ZMM(21), ZMM(21)) + #endif + + VMOVUPD(ZMM(5) , MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(21), YMM(1)) + VMOVUPD(YMM(21), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 8 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (7+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (7+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+9*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(9) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+8*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(7) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(23), 
ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+7*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(5), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(21) , ZMM(3)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(1)) + VSUBPD(ZMM(2), ZMM(18), ZMM(18)) + VSUBPD(ZMM(3), ZMM(1) , ZMM(1) ) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(18), ZMM(18)) + VMULPD(ZMM(0), ZMM(1) , ZMM(1) ) + #else + VDIVPD(ZMM(0), ZMM(18), ZMM(18)) + VDIVPD(ZMM(0), ZMM(1) , ZMM(1) ) + #endif + VMOVUPD(ZMM(1), MEM(R8, R12, 1)) + + VMOVUPD(ZMM(18), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(1), YMM(0)) + VMOVUPD(YMM(1), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 9 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (6+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (6+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+9*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(9) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+8*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(7) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+7*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(5) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(21), ZMM(3)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (6+6*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8 ), ZMM(0)) + VSUBPD(ZMM(2), ZMM(16), ZMM(16)) + VSUBPD(ZMM(3), ZMM(0) , ZMM(0) ) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(16), ZMM(16)) + VMULPD(ZMM(1), ZMM(0) , ZMM(0) ) + #else + VDIVPD(ZMM(1), ZMM(16), ZMM(16)) + VDIVPD(ZMM(1), ZMM(0) , ZMM(0) ) + #endif + + VMOVUPD(ZMM(0), MEM(R8 )) + VMOVUPD(ZMM(16), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(0), YMM(1)) + VMOVUPD(YMM(0), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 10 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (5+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (5+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+9*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(9) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+8*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(7) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+7*16)*8), 
ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(5) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(21), ZMM(3)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (5+6*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8 ), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (5+5*16)*8), ZMM(0)) + + VSUBPD(ZMM(2), ZMM(14), ZMM(14)) + VSUBPD(ZMM(3), ZMM(30), ZMM(30)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(14), ZMM(14)) + VMULPD(ZMM(0), ZMM(30), ZMM(30)) + #else + VDIVPD(ZMM(0), ZMM(14), ZMM(14)) + VDIVPD(ZMM(0), ZMM(30), ZMM(30)) + #endif + + VMOVUPD(ZMM(14), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(30), YMM(1)) + VMOVUPD(YMM(30), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 11 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (4+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (4+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (4+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (4+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (4+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (4+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (4+9*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(9) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (4+8*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(7) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (4+7*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(5) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(21), ZMM(3)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (4+6*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8 ), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (4+5*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (4+4*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(30), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(12), ZMM(12)) + VSUBPD(ZMM(3), ZMM(28), ZMM(28)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(12), ZMM(12)) + VMULPD(ZMM(1), ZMM(28), ZMM(28)) + #else + VDIVPD(ZMM(1), ZMM(12), ZMM(12)) + VDIVPD(ZMM(1), ZMM(28), ZMM(28)) + #endif + + VMOVUPD(ZMM(12), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(28), YMM(1)) + VMOVUPD(YMM(28), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 12 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (3+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (3+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (3+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (3+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (3+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), 
ZMM(3)) + + VBROADCASTSD(MEM(RAX, (3+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (3+9*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(9) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (3+8*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(7) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (3+7*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(5) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(21), ZMM(3)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (3+6*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8 ), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (3+5*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (3+4*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(30), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (3+3*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(28), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(10), ZMM(10)) + VSUBPD(ZMM(3), ZMM(26), ZMM(26)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(10), ZMM(10)) + VMULPD(ZMM(0), ZMM(26), ZMM(26)) + #else + VDIVPD(ZMM(0), ZMM(10), ZMM(10)) + VDIVPD(ZMM(0), ZMM(26), ZMM(26)) + #endif + + VMOVUPD(ZMM(10), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(26), YMM(1)) + VMOVUPD(YMM(26), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 13 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (2+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (2+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (2+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (2+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (2+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (2+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (2+9*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(9) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (2+8*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(7) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (2+7*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(5) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(21), ZMM(3)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (2+6*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8 ), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (2+5*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (2+4*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(30), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (2+3*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (2+2*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(26), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(8), ZMM(8)) + VSUBPD(ZMM(3), ZMM(24), ZMM(24)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(8), ZMM(8)) + VMULPD(ZMM(1), ZMM(24), ZMM(24)) + 
#else + VDIVPD(ZMM(1), ZMM(8), ZMM(8)) + VDIVPD(ZMM(1), ZMM(24), ZMM(24)) + #endif + + VMOVUPD(ZMM(8), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(24), YMM(1)) + VMOVUPD(YMM(24), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 14 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (1+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (1+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (1+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (1+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (1+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (1+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (1+9*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(9) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (1+8*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(7) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (1+7*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(5) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(21), ZMM(3)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (1+6*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8 ), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (1+5*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (1+4*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(30), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (1+3*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (1+2*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (1+1*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(24), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(6) , ZMM(6) ) + VSUBPD(ZMM(3), ZMM(22), ZMM(22)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(6) , ZMM(6) ) + VMULPD(ZMM(0), ZMM(22), ZMM(22)) + #else + VDIVPD(ZMM(0), ZMM(6) , ZMM(6) ) + VDIVPD(ZMM(0), ZMM(22), ZMM(22)) + #endif + + VMOVUPD(ZMM(6), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(22), YMM(0)) + VMOVUPD(YMM(22), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 15 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (0+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (0+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+9*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(9) , 
ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+8*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(7) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+7*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(5) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(21), ZMM(3)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (0+6*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8 ), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (0+5*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (0+4*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(30), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+3*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+2*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+1*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+0*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(22), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(4) , ZMM(4) ) + VSUBPD(ZMM(3), ZMM(20), ZMM(20)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(4) , ZMM(4) ) + VMULPD(ZMM(1), ZMM(20), ZMM(20)) + #else + VDIVPD(ZMM(1), ZMM(4) , ZMM(4) ) + VDIVPD(ZMM(1), ZMM(20), ZMM(20)) + #endif + + VMOVUPD(ZMM(4), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(20), YMM(1)) + VMOVUPD(YMM(20), MEM(RDX )) + VMOVUPD(XMM(1) , MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + /* + Storage Region (Post TRSM) + */ + MOV(R8, RCX) + MOV(R9, RDI) + MOV(VAR(cs_c), RSI) + + LEA(MEM(RCX, RSI, 8), RDX) // rdx = rcx + cs_c * 8 + LEA(MEM(RCX, RDI, 8), R14) // r14 = rcx + rs_c * 8 + + LEA(MEM(RSI, RSI, 2), R12) // cs_c * 3 + LEA(MEM(RSI, RSI, 4), R13) // cs_c * 5 + LEA(MEM(R13, RSI, 2), R15) // cs_c * 7 + + CMP(IMM(8), RSI) + JZ(.DROWSTORED) + + CMP(IMM(8), RDI) + JZ(.DCOLSTORED) + + LABEL(.DROWSTORED) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VMOVUPD(ZMM(4), MEM(RCX)) + VMOVUPD(MEM(R8, RDI, 1), ZMM(4)) + + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(20), YMM(0)) + VMOVUPD(YMM(20), MEM(RDX )) + + VMOVUPD(MEM(R8, RDI, 2), ZMM(20)) + + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(6), MEM(RCX)) + + VMOVUPD(MEM(R8, RDI, 4), ZMM(6)) + + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(22), YMM(0)) + VMOVUPD(YMM(22), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(8), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(24), YMM(0)) + VMOVUPD(YMM(24), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(10), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(26), YMM(0)) + VMOVUPD(YMM(26), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(12), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(28), YMM(0)) + VMOVUPD(YMM(28), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(14), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(30), YMM(0)) + VMOVUPD(YMM(30), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(16), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(1), YMM(0)) + VMOVUPD(YMM(1), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(18), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(4), YMM(0)) + VMOVUPD(YMM(4), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + 
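The unrolled iterations above all follow the same pattern: a chain of VFMADD231PD instructions accumulates the contributions of the rows already solved, VSUBPD removes that sum from the packed right-hand side, and the broadcast diagonal element is then applied, via VMULPD when TRSM preinversion is enabled (the diagonal is inverted during packing) or via VDIVPD otherwise. A scalar sketch of one such step, with hypothetical names, for orientation only:

double trsm_row_step_sketch( double rhs, const double *a_row, const double *x_solved,
                             int nsolved, double diag )
{
    double acc = 0.0;
    for ( int j = 0; j < nsolved; ++j )
        acc += a_row[ j ] * x_solved[ j ];   /* the VFMADD231PD chain            */
    double x = rhs - acc;                    /* the VSUBPD                       */
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
    x *= diag;                               /* diag holds 1/a_rr after packing  */
#else
    x /= diag;                               /* diag is the true a_rr            */
#endif
    return x;                                /* written back, as with VMOVUPD    */
}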
ADD(RDI, RDX) + + VMOVUPD(ZMM(5), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(21), YMM(0)) + VMOVUPD(YMM(21), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(7), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(23), YMM(0)) + VMOVUPD(YMM(23), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(9), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(25), YMM(0)) + VMOVUPD(YMM(25), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(11), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(27), YMM(0)) + VMOVUPD(YMM(27), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(13), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(29), YMM(0)) + VMOVUPD(YMM(29), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(15), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(31), YMM(0)) + VMOVUPD(YMM(31), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(17), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(20), YMM(0)) + VMOVUPD(YMM(20), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(19), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(6), YMM(0)) + VMOVUPD(YMM(6), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + + + JMP(.DDONE) + LABEL(.DCOLSTORED) + + + MOV(VAR(offsetPtr), R12) + LEA(MEM(RCX, RSI, 8), RDX) // rdx = rcx + cs_c * 8 + VPBROADCASTQ(RSI, ZMM(0)) + VPMULLQ(MEM(R12 ), ZMM(0), ZMM(2)) + VPMULLQ(MEM(R12,64), ZMM(0), ZMM(3)) // load offsets in zmm2, zmm3 + + VMOVUPD(MEM(RCX ), ZMM(0)) + VMOVUPD(MEM(RCX, RSI, 1), ZMM(1)) + + MOV(RDX, RCX) + LEA(MEM(RCX, RDI, 8), R14) + UPDATE_C_COL_SCATTERED_2x6(20,21) + VMOVUPD(MEM(R8, RSI, 2), ZMM(20)) + VMOVUPD(MEM(R8, RSI, 4), ZMM(21)) + UPDATE_C_COL_SCATTERED_2x6(22,23) + UPDATE_C_COL_SCATTERED_2x6(24,25) + UPDATE_C_COL_SCATTERED_2x6(26,27) + UPDATE_C_COL_SCATTERED_2x6(28,29) + UPDATE_C_COL_SCATTERED_2x6(30,31) + UPDATE_C_COL_SCATTERED_2x6(0 ,20 ) + UPDATE_C_COL_SCATTERED_2x6(1 ,21 ) + + MOV(R8, RCX) + LEA(MEM(RCX, RDI, 8), R14) + UPDATE_C_COL_SCATTERED( 4, 5) + UPDATE_C_COL_SCATTERED( 6, 7) + UPDATE_C_COL_SCATTERED( 8, 9) + UPDATE_C_COL_SCATTERED(10,11) + UPDATE_C_COL_SCATTERED(12,13) + UPDATE_C_COL_SCATTERED(14,15) + UPDATE_C_COL_SCATTERED(16,17) + UPDATE_C_COL_SCATTERED(18,19) + + + LABEL(.DDONE) + + VZEROUPPER() + + end_asm( + : // output operands (none) + : // input operands + [a10] "m" (a10), // 1 + [k] "m" (k), // 2 + [b01] "m" (b01), // 3 + [a11] "m" (a11), // 6 + [b11] "m" (b11), // 7 + [c11] "m" (c11), // 8 + [rs_c] "m" (rs_c), // 9 + [cs_c] "m" (cs_c), // 10, + [alpha] "m" (alpha), + [offsetPtr] "m" (offsetPtr) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12", + "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", + "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", + "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", + "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", + "zmm30", "zmm31", "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_9); +} \ No newline at end of file diff --git a/kernels/zen4/CMakeLists.txt b/kernels/zen4/CMakeLists.txt new file mode 100644 index 0000000000..c22c5ba143 --- /dev/null +++ b/kernels/zen4/CMakeLists.txt @@ -0,0 +1,6 @@ +##Copyright (C) 2022, Advanced Micro Devices, Inc. 
All rights reserved.## + +add_subdirectory(1) +add_subdirectory(3) + + diff --git a/kernels/zen4/README b/kernels/zen4/README new file mode 100644 index 0000000000..c9e16c2735 --- /dev/null +++ b/kernels/zen4/README @@ -0,0 +1 @@ +Currently there are no zen4-specific kernels; however, this folder is required by the build system. diff --git a/kernels/zen4/bli_kernels_zen4.h b/kernels/zen4/bli_kernels_zen4.h new file mode 100644 index 0000000000..e518a86047 --- /dev/null +++ b/kernels/zen4/bli_kernels_zen4.h @@ -0,0 +1,42 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// -- level-1v -- + +// amaxv (intrinsics) +AMAXV_KER_PROT( float, s, amaxv_zen_int_avx512 ) +AMAXV_KER_PROT( double, d, amaxv_zen_int_avx512 ) + +GEMMTRSM_UKR_PROT( double, d, gemmtrsm_l_zen_asm_16x14) +GEMMTRSM_UKR_PROT( double, d, gemmtrsm_u_zen_asm_16x14) \ No newline at end of file diff --git a/ref_kernels/CMakeLists.txt b/ref_kernels/CMakeLists.txt index 61357c1fec..d26bce06a5 100644 --- a/ref_kernels/CMakeLists.txt +++ b/ref_kernels/CMakeLists.txt @@ -5,6 +5,7 @@ add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/generic ${CMAKE_BINARY_DIR}/ref add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen ${CMAKE_BINARY_DIR}/ref_kernels/zen) add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen2 ${CMAKE_BINARY_DIR}/ref_kernels/zen2) add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen3 ${CMAKE_BINARY_DIR}/ref_kernels/zen3) +add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen4 ${CMAKE_BINARY_DIR}/ref_kernels/zen4) else() target_sources("${PROJECT_NAME}" PRIVATE diff --git a/sandbox/gemmlike/bli_gemmnat.c b/sandbox/gemmlike/bli_gemmnat.c new file mode 100644 index 0000000000..37fb701859 --- /dev/null +++ b/sandbox/gemmlike/bli_gemmnat.c @@ -0,0 +1,88 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries.
+ + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// Given the current architecture of BLIS sandboxes, bli_gemmnat() is the +// entry point to any sandbox implementation. + +// NOTE: This function is implemented identically to the function that it +// overrides in frame/ind/oapi/bli_l3_nat_oapi.c. This means that we are +// forgoing the option of customizing the implementations that underlie +// bli_gemm() and bli_?gemm(). Any new code defined in this sandbox +// directory, however, will be included in the BLIS. + +#include "blis.h" + +#undef GENFRONT +#define GENFRONT( opname, cname, imeth ) \ +\ +void PASTEMAC(opname,imeth) \ + ( \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ +\ + /* A switch to easily toggle whether we use the sandbox implementation + of bls_gemm() as the implementation for bli_gemm(). (This allows for + easy testing of bls_gemm() via the testsuite.) */ \ + if ( 1 ) \ + { \ + bls_gemm_ex( alpha, a, b, beta, c, cntx, rntm ); \ + return; \ + } \ +\ + bli_init_once(); \ +\ + /* Obtain a valid (native) context from the gks if necessary. */ \ + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \ +\ + /* Initialize a local runtime with global settings if necessary. Note + that in the case that a runtime is passed in, we make a local copy. */ \ + rntm_t rntm_l; \ + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ + else { rntm_l = *rntm; rntm = &rntm_l; } \ +\ + /* Invoke the operation's front end. */ \ + PASTEMAC(opname,_front) \ + ( \ + alpha, a, b, beta, c, cntx, rntm, NULL \ + ); \ +} + +GENFRONT( gemm, gemm, nat ) diff --git a/sandbox/gemmlike/bli_sandbox.h b/sandbox/gemmlike/bli_sandbox.h new file mode 100644 index 0000000000..d6e6522e8c --- /dev/null +++ b/sandbox/gemmlike/bli_sandbox.h @@ -0,0 +1,56 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
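Because bli_gemmnat.c above overrides the native gemm entry point, an ordinary object-API call is enough to exercise the sandbox once BLIS is built with it (typically by configuring with the gemmlike sandbox enabled). A minimal, hypothetical driver using only standard BLIS object-API calls, not part of the patch:

#include "blis.h"

int main( void )
{
    obj_t a, b, c, alpha, beta;

    bli_obj_create( BLIS_DOUBLE, 512, 256, 0, 0, &a );
    bli_obj_create( BLIS_DOUBLE, 256, 384, 0, 0, &b );
    bli_obj_create( BLIS_DOUBLE, 512, 384, 0, 0, &c );
    bli_obj_create_1x1( BLIS_DOUBLE, &alpha );
    bli_obj_create_1x1( BLIS_DOUBLE, &beta );

    bli_randm( &a ); bli_randm( &b ); bli_randm( &c );
    bli_setsc( 1.0, 0.0, &alpha );
    bli_setsc( 1.0, 0.0, &beta );

    /* With the gemmlike sandbox compiled in, this call is expected to reach
       the bli_gemmnat() defined above and be forwarded to bls_gemm_ex(). */
    bli_gemm( &alpha, &a, &b, &beta, &c );

    bli_obj_free( &a ); bli_obj_free( &b ); bli_obj_free( &c );
    bli_obj_free( &alpha ); bli_obj_free( &beta );
    return 0;
}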
+ + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of copyright holder(s) nor the names + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SANDBOX_H +#define BLIS_SANDBOX_H + +// NOTE: This header is the only header required to be present in the sandbox +// implementation directory. + +// This header should contain (or #include) any definitions that must be +// folded into blis.h. Typically, it will remain empty since any header +// definitions specific to the sandbox implementation will not need to be +// made available to applications (or the framework) during compilation. + +#include "bls_gemm.h" +#include "bls_gemm_var.h" + +#include "bls_l3_packm_a.h" +#include "bls_l3_packm_b.h" +#include "bls_l3_packm_var.h" + +#include "bls_l3_decor.h" + + +#endif diff --git a/sandbox/gemmlike/bls_gemm.c b/sandbox/gemmlike/bls_gemm.c new file mode 100644 index 0000000000..3e4c9b2a33 --- /dev/null +++ b/sandbox/gemmlike/bls_gemm.c @@ -0,0 +1,304 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// -- Define the gemm-like operation's object API ------------------------------ +// + +void bls_gemm + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c + ) +{ + bls_gemm_ex + ( + alpha, + a, + b, + beta, + c, + NULL, + NULL + ); +} + +void bls_gemm_ex + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + // -- bli_gemmnat() -------------------------------------------------------- + + // Obtain a valid (native) context from the gks if necessary. + // NOTE: This must be done before calling the _check() function, since + // that function assumes the context pointer is valid. + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // -- bli_gemm_front() ----------------------------------------------------- + + obj_t a_local; + obj_t b_local; + obj_t c_local; + + // Check parameters. + if ( bli_error_checking_is_enabled() ) + { + bli_gemm_check( alpha, a, b, beta, c, cntx ); + } + + // If C has a zero dimension, return early. + if ( bli_obj_has_zero_dim( c ) ) + { + return; + } + + // If alpha is zero, or if A or B has a zero dimension, scale C by beta + // and return early. + if ( bli_obj_equals( alpha, &BLIS_ZERO ) || + bli_obj_has_zero_dim( a ) || + bli_obj_has_zero_dim( b ) ) + { + bli_scalm( beta, c ); + return; + } + + // Alias A, B, and C in case we need to apply transformations. + bli_obj_alias_to( a, &a_local ); + bli_obj_alias_to( b, &b_local ); + bli_obj_alias_to( c, &c_local ); + + // Induce a transposition of A if it has its transposition property set. + // Then clear the transposition bit in the object. + if ( bli_obj_has_trans( &a_local ) ) + { + bli_obj_induce_trans( &a_local ); + bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &a_local ); + } + + // Induce a transposition of B if it has its transposition property set. + // Then clear the transposition bit in the object. + if ( bli_obj_has_trans( &b_local ) ) + { + bli_obj_induce_trans( &b_local ); + bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &b_local ); + } + + // An optimization: If C is stored by rows and the micro-kernel prefers + // contiguous columns, or if C is stored by columns and the micro-kernel + // prefers contiguous rows, transpose the entire operation to allow the + // micro-kernel to access elements of C in its preferred manner. + if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) + { + bli_obj_swap( &a_local, &b_local ); + + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + + // NOTE: This is probably not needed within the sandbox. 
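The swap performed above relies on the identity (A B)^T = B^T A^T: transposing C, exchanging A and B, and transposing both operands computes the same product while letting the microkernel touch C in its preferred storage, and since inducing a transposition in BLIS only exchanges an object's row and column strides, no data is moved. A small self-contained check of the identity (ordinary C, not BLIS code):

#include <stdio.h>

/* Multiplies row-major A (m x k) by row-major B (k x n) into C (m x n). */
static void matmul( int m, int n, int k, const double *A, const double *B, double *C )
{
    for ( int i = 0; i < m; ++i )
        for ( int j = 0; j < n; ++j )
        {
            double s = 0.0;
            for ( int p = 0; p < k; ++p ) s += A[ i*k + p ] * B[ p*n + j ];
            C[ i*n + j ] = s;
        }
}

int main( void )
{
    double A[ 2*3 ] = { 1, 2, 3, 4, 5, 6 };
    double B[ 3*2 ] = { 7, 8, 9, 10, 11, 12 };
    double C[ 2*2 ], At[ 3*2 ], Bt[ 2*3 ], Ct[ 2*2 ];

    matmul( 2, 2, 3, A, B, C );                              /* C  = A * B     */

    for ( int i = 0; i < 2; ++i )
        for ( int j = 0; j < 3; ++j ) At[ j*2 + i ] = A[ i*3 + j ];
    for ( int i = 0; i < 3; ++i )
        for ( int j = 0; j < 2; ++j ) Bt[ j*3 + i ] = B[ i*2 + j ];
    matmul( 2, 2, 3, Bt, At, Ct );                           /* Ct = B^T * A^T */

    for ( int i = 0; i < 2; ++i )
        for ( int j = 0; j < 2; ++j )
            printf( "%g %g\n", C[ i*2 + j ], Ct[ j*2 + i ] ); /* each pair matches */
    return 0;
}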
+ // We must also swap the pack schemas, which were set by bli_gemm_md() + // or the inlined code above. + //bli_obj_swap_pack_schemas( &a_local, &b_local ); + } + + // Parse and interpret the contents of the rntm_t object to properly + // set the ways of parallelism for each loop, and then make any + // additional modifications necessary for the current operation. + bli_rntm_set_ways_for_op + ( + BLIS_GEMM, + BLIS_LEFT, // ignored for gemm/hemm/symm + bli_obj_length( &c_local ), + bli_obj_width( &c_local ), + bli_obj_width( &a_local ), + rntm + ); + + // Spawn threads (if applicable), where bls_gemm_int() is the thread entry + // point function for each thread. This also begins the process of creating + // the thrinfo_t tree, which contains thread communicators. + bls_l3_thread_decorator + ( + bls_gemm_int, + BLIS_GEMM, // operation family id + alpha, + &a_local, + &b_local, + beta, + &c_local, + cntx, + rntm + ); +} + +// +// -- Define the gemm-like operation's thread entry point ---------------------- +// + +void bls_gemm_int + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + // In this function, we choose the gemm implementation that is executed + // on each thread. + +#if 1 + // Call the block-panel algorithm that calls the kernel directly, which + // exposes edge-case handling. + bls_gemm_bp_var1 + ( + alpha, + a, + b, + beta, + c, + cntx, + rntm, + thread + ); +#else + // Call the block-panel algorithm that calls the kernel indirectly via a + // wrapper function, which hides edge-case handling. + bls_gemm_bp_var2 + ( + alpha, + a, + b, + beta, + c, + cntx, + rntm, + thread + ); +#endif +} + +// +// -- Define the gemm-like operation's typed API ------------------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c \ + ) \ +{ \ + bli_init_once(); \ +\ + /* Determine the datatype (e.g. BLIS_FLOAT, BLIS_DOUBLE, etc.) based on + the macro parameter 'ch' (e.g. s, d, etc). */ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao, ao, bo, betao, co; \ +\ + dim_t m_a, n_a; \ + dim_t m_b, n_b; \ +\ + /* Adjust the dimensions of matrices A and B according to the transa and + transb parameters. */ \ + bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ + bli_set_dims_with_trans( transb, k, n, &m_b, &n_b ); \ +\ + /* Create bufferless scalar objects and attach the provided scalar pointers + to those scalar objects. */ \ + bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ + bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ +\ + /* Create bufferless matrix objects and attach the provided matrix pointers + to those matrix objects. */ \ + bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_create_with_attached_buffer( dt, m, n, c, rs_c, cs_c, &co ); \ +\ + /* Set the transposition/conjugation properties of the objects for matrices + A and B. */ \ + bli_obj_set_conjtrans( transa, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + /* Call the object interface. 
*/ \ + PASTECH(bls_,opname) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co \ + ); \ +} + +//INSERT_GENTFUNC_BASIC0( gemm ) +GENTFUNC( float, s, gemm ) +GENTFUNC( double, d, gemm ) +GENTFUNC( scomplex, c, gemm ) +GENTFUNC( dcomplex, z, gemm ) + diff --git a/sandbox/gemmlike/bls_gemm.h b/sandbox/gemmlike/bls_gemm.h new file mode 100644 index 0000000000..b296ac1c0f --- /dev/null +++ b/sandbox/gemmlike/bls_gemm.h @@ -0,0 +1,101 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// +// -- Prototype the gemm-like operation's object API --------------------------- +// + +void bls_gemm + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c + ); + +void bls_gemm_ex + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ); + +// +// -- Prototype the gemm-like operation's thread entry point ------------------- +// + +void bls_gemm_int + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +// +// -- Prototype the gemm-like operation's typed API ---------------------------- +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c \ + ); + +//INSERT_GENTPROT_BASIC0( gemm ) +GENTPROT( float, s, gemm ) +GENTPROT( double, d, gemm ) +GENTPROT( scomplex, c, gemm ) +GENTPROT( dcomplex, z, gemm ) + diff --git a/sandbox/gemmlike/bls_gemm_bp_var1.c b/sandbox/gemmlike/bls_gemm_bp_var1.c new file mode 100644 index 0000000000..330a94801b --- /dev/null +++ b/sandbox/gemmlike/bls_gemm_bp_var1.c @@ -0,0 +1,521 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
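The GENTFUNC/GENTPROT pairs above expand to typed entry points such as bls_sgemm() and bls_dgemm(). A minimal, hypothetical driver for the double-precision instance, following the rs/cs stride convention used by the wrapper (this assumes BLIS was built with this sandbox so that blis.h pulls in bls_gemm.h):

#include "blis.h"

int main( void )
{
    dim_t m = 4, n = 3, k = 2;
    double alpha = 1.0, beta = 0.0;

    /* Row-major buffers: rs = number of columns, cs = 1. */
    double A[ 4*2 ] = { 0.0 }, B[ 2*3 ] = { 0.0 }, C[ 4*3 ] = { 0.0 };
    A[ 0 ] = 1.0;  B[ 0 ] = 2.0;

    bls_dgemm( BLIS_NO_TRANSPOSE, BLIS_NO_TRANSPOSE,
               m, n, k,
               &alpha,
               A, k, 1,      /* rs_a, cs_a */
               B, n, 1,      /* rs_b, cs_b */
               &beta,
               C, n, 1 );    /* rs_c, cs_c */

    /* C[0][0] should now equal alpha * A[0][0] * B[0][0] = 2.0. */
    return 0;
}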
+ + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define FUNCPTR_T gemm_fp + +typedef void (*FUNCPTR_T) + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + void* restrict alpha, + void* restrict a, inc_t rs_a, inc_t cs_a, + void* restrict b, inc_t rs_b, inc_t cs_b, + void* restrict beta, + void* restrict c, inc_t rs_c, inc_t cs_c, + cntx_t* restrict cntx, + rntm_t* restrict rntm, + thrinfo_t* restrict thread + ); + +// +// -- gemm-like block-panel algorithm (object interface) ----------------------- +// + +// Define a function pointer array named ftypes and initialize its contents with +// the addresses of the typed functions defined below, bls_?gemm_bp_var1(). +static FUNCPTR_T GENARRAY_PREF(ftypes,bls_,gemm_bp_var1); + +void bls_gemm_bp_var1 + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + const num_t dt = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width( a ); + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + const inc_t rs_a = bli_obj_row_stride( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const inc_t cs_b = bli_obj_col_stride( b ); + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt, beta ); + + // Index into the function pointer array to extract the correct + // typed function pointer based on the chosen datatype. + FUNCPTR_T f = ftypes[dt]; + + // Invoke the function. 
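The GENARRAY_PREF line above builds a static table of typed function pointers indexed by datatype, which bls_gemm_bp_var1() then dereferences with the runtime datatype of C. A stripped-down sketch of that dispatch pattern (illustrative names only, not BLIS code):

typedef void (*gemm_fp_sketch)( int m, int n, int k,
                                const void *a, const void *b, void *c );

static void sgemm_sketch( int m, int n, int k, const void *a, const void *b, void *c ) { /* float  path */ }
static void dgemm_sketch( int m, int n, int k, const void *a, const void *b, void *c ) { /* double path */ }

enum { DT_S = 0, DT_D = 1, DT_COUNT };

static gemm_fp_sketch ftypes_sketch[ DT_COUNT ] = { sgemm_sketch, dgemm_sketch };

/* Mirrors "FUNCPTR_T f = ftypes[dt]; f( ... );" in the object wrapper. */
static void dispatch_sketch( int dt, int m, int n, int k,
                             const void *a, const void *b, void *c )
{
    ftypes_sketch[ dt ]( m, n, k, a, b, c );
}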
+ f + ( + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + cntx, + rntm, + thread + ); +} + +// +// -- gemm-like block-panel algorithm (typed interface) ------------------------ +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t NC = bli_cntx_get_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t MC = bli_cntx_get_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t KC = bli_cntx_get_blksz_def_dt( dt, BLIS_KC, cntx ); \ +\ + /* Query the context for the microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = cs_c; \ + const inc_t jcstep_b = cs_b; \ +\ + const inc_t pcstep_a = cs_a; \ + const inc_t pcstep_b = rs_b; \ +\ + const inc_t icstep_c = rs_c; \ + const inc_t icstep_a = rs_a; \ +\ + const inc_t jrstep_c = cs_c * NR; \ +\ + const inc_t irstep_c = rs_c * MR; \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + /* Make local copies of the scalars to prevent any unnecessary sharing of + cache lines between the cores' caches. */ \ + ctype alpha_local = *alpha_cast; \ + ctype beta_local = *beta_cast; \ + ctype one_local = *PASTEMAC(ch,1); \ + ctype zero_local = *PASTEMAC(ch,0); \ +\ + auxinfo_t aux; \ +\ + /* Initialize a mem_t entry for A and B. Strictly speaking, this is only + needed for the matrix we will be packing (if any), but we do it + unconditionally to be safe. */ \ + mem_t mem_a = BLIS_MEM_INITIALIZER; \ + mem_t mem_b = BLIS_MEM_INITIALIZER; \ +\ + /* Define an array of bszid_t ids, which will act as our substitute for + the cntl_t tree. 
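What follows is the classic five-loop blocked algorithm: the JC, PC, and IC loops carve C, B, and A into NC-, KC-, and MC-sized pieces that fit the cache hierarchy (packing B and A along the way), and the JR and IR loops walk the NR x MR micropanels handed to the microkernel. A compact outline in plain C, with packing, threading, and edge handling omitted (hypothetical, for orientation only):

static int min_i( int a, int b ) { return a < b ? a : b; }

void gemm_five_loop_sketch( int m, int n, int k,
                            int MC, int NC, int KC, int MR, int NR )
{
    for ( int jc = 0; jc < n; jc += NC )                            /* 5th loop (JC) */
      for ( int pc = 0; pc < k; pc += KC )                          /* 4th loop (PC): pack a KC x NC panel of B */
        for ( int ic = 0; ic < m; ic += MC )                        /* 3rd loop (IC): pack an MC x KC block of A */
          for ( int jr = jc; jr < min_i( jc + NC, n ); jr += NR )   /* 2nd loop (JR) */
            for ( int ir = ic; ir < min_i( ic + MC, m ); ir += MR ) /* 1st loop (IR) */
            {
                /* Microkernel: update the MR x NR block of C at (ir, jr)
                   with a rank-KC product of packed micropanels of A and B. */
            }
}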
*/ \ + bszid_t bszids[8] = { BLIS_NC, /* 5th loop */ \ + BLIS_KC, /* 4th loop */ \ + BLIS_NO_PART, /* pack B */ \ + BLIS_MC, /* 3rd loop */ \ + BLIS_NO_PART, /* pack A */ \ + BLIS_NR, /* 2nd loop */ \ + BLIS_MR, /* 1st loop */ \ + BLIS_KR }; /* microkernel loop */ \ +\ + bszid_t* restrict bszids_jc = &bszids[0]; \ + bszid_t* restrict bszids_pc = &bszids[1]; \ + /*bszid_t* restrict bszids_pb = &bszids[2];*/ \ + bszid_t* restrict bszids_ic = &bszids[3]; \ + /*bszid_t* restrict bszids_pa = &bszids[4];*/ \ + bszid_t* restrict bszids_jr = &bszids[5]; \ + /*bszid_t* restrict bszids_ir = &bszids[6];*/ \ +\ + thrinfo_t* restrict thread_jc = NULL; \ + thrinfo_t* restrict thread_pc = NULL; \ + thrinfo_t* restrict thread_pb = NULL; \ + thrinfo_t* restrict thread_ic = NULL; \ + thrinfo_t* restrict thread_pa = NULL; \ + thrinfo_t* restrict thread_jr = NULL; \ + thrinfo_t* restrict thread_ir = NULL; \ +\ + /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ + PASTEMAC(ch,set0s_mxn)( MR, NR, ct, rs_ct, cs_ct ); \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jc = thread; \ + bli_thrinfo_sup_grow( rntm, bszids_jc, thread_jc ); \ +\ + /* Compute the JC loop thread range for the current thread. */ \ + dim_t jc_start, jc_end; \ + bli_thread_range_sub( thread_jc, n, NR, FALSE, &jc_start, &jc_end ); \ + const dim_t n_local = jc_end - jc_start; \ +\ + /* Compute number of primary and leftover components of the JC loop. */ \ + /*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ \ + const dim_t jc_left = n_local % NC; \ +\ + /* Loop over the n dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = jc_start; jj < jc_end; jj += NC ) \ + { \ + /* Calculate the thread's current JC block dimension. */ \ + const dim_t nc_cur = ( NC <= jc_end - jj ? NC : jc_left ); \ +\ + ctype* restrict b_jc = b_00 + jj * jcstep_b; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_pc = bli_thrinfo_sub_node( thread_jc ); \ + bli_thrinfo_sup_grow( rntm, bszids_pc, thread_pc ); \ +\ + /* Compute the PC loop thread range for the current thread. */ \ + const dim_t pc_start = 0, pc_end = k; \ + const dim_t k_local = k; \ +\ + /* Compute number of primary and leftover components of the PC loop. */ \ + /*const dim_t pc_iter = ( k_local + KC - 1 ) / KC;*/ \ + const dim_t pc_left = k_local % KC; \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = pc_start; pp < pc_end; pp += KC ) \ + { \ + /* Calculate the thread's current PC block dimension. */ \ + const dim_t kc_cur = ( KC <= pc_end - pp ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_00 + pp * pcstep_a; \ + ctype* restrict b_pc = b_jc + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? &beta_local : &one_local ); \ +\ + ctype* b_use; \ + inc_t rs_b_use, cs_b_use, ps_b_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pb = bli_thrinfo_sub_node( thread_pc ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pb, thread_pb );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + B. Then call the packm implementation. 
*/ \ + PASTECH2(bls_,ch,packm_b) \ + ( \ + conjb, \ + KC, NC, \ + kc_cur, nc_cur, NR, \ + &one_local, \ + b_pc, rs_b, cs_b, \ + &b_use, &rs_b_use, &cs_b_use, \ + &ps_b_use, \ + cntx, \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ + /* Alias b_use so that it's clear this is our current block of + matrix B. */ \ + ctype* restrict b_pc_use = b_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_ic = bli_thrinfo_sub_node( thread_pb ); \ + bli_thrinfo_sup_grow( rntm, bszids_ic, thread_ic ); \ +\ + /* Compute the IC loop thread range for the current thread. */ \ + dim_t ic_start, ic_end; \ + bli_thread_range_sub( thread_ic, m, MR, FALSE, &ic_start, &ic_end ); \ + const dim_t m_local = ic_end - ic_start; \ +\ + /* Compute number of primary and leftover components of the IC loop. */ \ + /*const dim_t ic_iter = ( m_local + MC - 1 ) / MC;*/ \ + const dim_t ic_left = m_local % MC; \ +\ + /* Loop over the m dimension (MC rows at a time). */ \ + for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \ + { \ + /* Calculate the thread's current IC block dimension. */ \ + const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \ +\ + ctype* restrict a_ic = a_pc + ii * icstep_a; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + ctype* a_use; \ + inc_t rs_a_use, cs_a_use, ps_a_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pa = bli_thrinfo_sub_node( thread_ic ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pa, thread_pa );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + A. Then call the packm implementation. */ \ + PASTECH2(bls_,ch,packm_a) \ + ( \ + conja, \ + MC, KC, \ + mc_cur, kc_cur, MR, \ + &one_local, \ + a_ic, rs_a, cs_a, \ + &a_use, &rs_a_use, &cs_a_use, \ + &ps_a_use, \ + cntx, \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ +\ + /* Alias a_use so that it's clear this is our current block of + matrix A. */ \ + ctype* restrict a_ic_use = a_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jr = bli_thrinfo_sub_node( thread_pa ); \ + bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \ +\ + /* Query the number of threads and thread ids for the JR loop. + NOTE: These values are only needed when computing the next + micropanel of B. */ \ + const dim_t jr_nt = bli_thread_n_way( thread_jr ); \ + const dim_t jr_tid = bli_thread_work_id( thread_jr ); \ +\ + /* Compute number of primary and leftover components of the JR loop. */ \ + dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ + dim_t jr_left = nc_cur % NR; \ +\ + /* Compute the JR loop thread range for the current thread. */ \ + dim_t jr_start, jr_end; \ + bli_thread_range_sub( thread_jr, jr_iter, 1, FALSE, &jr_start, &jr_end ); \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( dim_t j = jr_start; j < jr_end; j += 1 ) \ + { \ + const dim_t nr_cur \ + = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); \ +\ + ctype* restrict b_jr = b_pc_use + j * ps_b_use; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* Assume for now that our next panel of B to be the current panel + of B. */ \ + ctype* restrict b2 = b_jr; \ +\ + /* Identify the current thrinfo_t node. 
*/ \ + thread_ir = bli_thrinfo_sub_node( thread_jr ); \ +\ + /* Query the number of threads and thread ids for the IR loop. + NOTE: These values are only needed when computing the next + micropanel of A. */ \ + const dim_t ir_nt = bli_thread_n_way( thread_ir ); \ + const dim_t ir_tid = bli_thread_work_id( thread_ir ); \ +\ + /* Compute number of primary and leftover components of the IR loop. */ \ + dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ + dim_t ir_left = mc_cur % MR; \ +\ + /* Compute the IR loop thread range for the current thread. */ \ + dim_t ir_start, ir_end; \ + bli_thread_range_sub( thread_ir, ir_iter, 1, FALSE, &ir_start, &ir_end ); \ +\ + /* Loop over the m dimension (MR rows at a time). */ \ + for ( dim_t i = ir_start; i < ir_end; i += 1 ) \ + { \ + const dim_t mr_cur \ + = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \ +\ + ctype* restrict a_ir = a_ic_use + i * ps_a_use; \ + ctype* restrict c_ir = c_jr + i * irstep_c; \ +\ + ctype* restrict a2; \ +\ + /* Compute the addresses of the next micropanels of A and B. */ \ + a2 = bli_gemm_get_next_a_upanel( a_ir, ps_a_use, 1 ); \ + if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ + { \ + a2 = a_ic_use; \ + b2 = bli_gemm_get_next_b_upanel( b_jr, ps_b_use, 1 ); \ + if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ + b2 = b_pc_use; \ + } \ +\ + /* Save the addresses of next micropanels of A and B to the + auxinfo_t object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* Handle interior and edge cases separately. */ \ + if ( mr_cur == MR && nr_cur == NR ) \ + { \ + /* Invoke the gemm microkernel. */ \ + gemm_ukr \ + ( \ + kc_cur, \ + &alpha_local, \ + a_ir, \ + b_jr, \ + beta_use, \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + else \ + { \ + /* Invoke the gemm microkernel. */ \ + gemm_ukr \ + ( \ + kc_cur, \ + &alpha_local, \ + a_ir, \ + b_jr, \ + &zero_local, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ +\ + /* Scale the bottom edge of C and add the result from above. */ \ + PASTEMAC(ch,xpbys_mxn) \ + ( \ + mr_cur, \ + nr_cur, \ + ct, rs_ct, cs_ct, \ + beta_use, \ + c_ir, rs_c, cs_c \ + ); \ + } \ + } \ + } \ + } \ +\ + /* This barrier is needed to prevent threads from starting to pack + the next row panel of B before the current row panel is fully + computed upon. */ \ + bli_thread_barrier( thread_pb ); \ + } \ + } \ +\ + /* Release any memory that was acquired for packing matrices A and B. */ \ + PASTECH2(bls_,ch,packm_finalize_mem_a) \ + ( \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ + PASTECH2(bls_,ch,packm_finalize_mem_b) \ + ( \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemm_bp_var1: a1_packed", mr_cur, kc_cur, a_ir, rs_a_use, cs_a_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemm_bp_var1: b1_packed", kc_cur, nr_cur, b_jr, rs_b_use, cs_b_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemm_bp_var1: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%5.2f", "" ); \ +*/ \ +} + +//INSERT_GENTFUNC_BASIC0( gemm_bp_var1 ) +GENTFUNC( float, s, gemm_bp_var1 ) +GENTFUNC( double, d, gemm_bp_var1 ) +GENTFUNC( scomplex, c, gemm_bp_var1 ) +GENTFUNC( dcomplex, z, gemm_bp_var1 ) + diff --git a/sandbox/gemmlike/bls_gemm_bp_var2.c b/sandbox/gemmlike/bls_gemm_bp_var2.c new file mode 100644 index 0000000000..22df767aea --- /dev/null +++ b/sandbox/gemmlike/bls_gemm_bp_var2.c @@ -0,0 +1,596 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
+ + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define FUNCPTR_T gemm_fp + +typedef void (*FUNCPTR_T) + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + void* restrict alpha, + void* restrict a, inc_t rs_a, inc_t cs_a, + void* restrict b, inc_t rs_b, inc_t cs_b, + void* restrict beta, + void* restrict c, inc_t rs_c, inc_t cs_c, + cntx_t* restrict cntx, + rntm_t* restrict rntm, + thrinfo_t* restrict thread + ); + +// +// -- gemm-like block-panel algorithm (object interface) ----------------------- +// + +// Define a function pointer array named ftypes and initialize its contents with +// the addresses of the typed functions defined below, bls_?gemm_bp_var2(). +static FUNCPTR_T GENARRAY_PREF(ftypes,bls_,gemm_bp_var2); + +void bls_gemm_bp_var2 + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + const num_t dt = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width( a ); + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + const inc_t rs_a = bli_obj_row_stride( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const inc_t cs_b = bli_obj_col_stride( b ); + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt, beta ); + + // Index into the function pointer array to extract the correct + // typed function pointer based on the chosen datatype. + FUNCPTR_T f = ftypes[dt]; + + // Invoke the function. 
+ f + ( + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + cntx, + rntm, + thread + ); +} + +// +// -- gemm-like block-panel algorithm (typed interface) ------------------------ +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t NC = bli_cntx_get_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t MC = bli_cntx_get_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t KC = bli_cntx_get_blksz_def_dt( dt, BLIS_KC, cntx ); \ +\ + /* Query the context for the microkernel address and cast it to its + function pointer type. */ \ + /* + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ + */ \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + /* + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ + */ \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = cs_c; \ + const inc_t jcstep_b = cs_b; \ +\ + const inc_t pcstep_a = cs_a; \ + const inc_t pcstep_b = rs_b; \ +\ + const inc_t icstep_c = rs_c; \ + const inc_t icstep_a = rs_a; \ +\ + const inc_t jrstep_c = cs_c * NR; \ +\ + const inc_t irstep_c = rs_c * MR; \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + /* Make local copies of the scalars to prevent any unnecessary sharing of + cache lines between the cores' caches. */ \ + ctype alpha_local = *alpha_cast; \ + ctype beta_local = *beta_cast; \ + ctype one_local = *PASTEMAC(ch,1); \ + /*ctype zero_local = *PASTEMAC(ch,0);*/ \ +\ + auxinfo_t aux; \ +\ + /* Initialize a mem_t entry for A and B. Strictly speaking, this is only + needed for the matrix we will be packing (if any), but we do it + unconditionally to be safe. */ \ + mem_t mem_a = BLIS_MEM_INITIALIZER; \ + mem_t mem_b = BLIS_MEM_INITIALIZER; \ +\ + /* Define an array of bszid_t ids, which will act as our substitute for + the cntl_t tree. 
*/ \ + bszid_t bszids[8] = { BLIS_NC, /* 5th loop */ \ + BLIS_KC, /* 4th loop */ \ + BLIS_NO_PART, /* pack B */ \ + BLIS_MC, /* 3rd loop */ \ + BLIS_NO_PART, /* pack A */ \ + BLIS_NR, /* 2nd loop */ \ + BLIS_MR, /* 1st loop */ \ + BLIS_KR }; /* microkernel loop */ \ +\ + bszid_t* restrict bszids_jc = &bszids[0]; \ + bszid_t* restrict bszids_pc = &bszids[1]; \ + /*bszid_t* restrict bszids_pb = &bszids[2];*/ \ + bszid_t* restrict bszids_ic = &bszids[3]; \ + /*bszid_t* restrict bszids_pa = &bszids[4];*/ \ + bszid_t* restrict bszids_jr = &bszids[5]; \ + /*bszid_t* restrict bszids_ir = &bszids[6];*/ \ +\ + thrinfo_t* restrict thread_jc = NULL; \ + thrinfo_t* restrict thread_pc = NULL; \ + thrinfo_t* restrict thread_pb = NULL; \ + thrinfo_t* restrict thread_ic = NULL; \ + thrinfo_t* restrict thread_pa = NULL; \ + thrinfo_t* restrict thread_jr = NULL; \ + thrinfo_t* restrict thread_ir = NULL; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jc = thread; \ + bli_thrinfo_sup_grow( rntm, bszids_jc, thread_jc ); \ +\ + /* Compute the JC loop thread range for the current thread. */ \ + dim_t jc_start, jc_end; \ + bli_thread_range_sub( thread_jc, n, NR, FALSE, &jc_start, &jc_end ); \ + const dim_t n_local = jc_end - jc_start; \ +\ + /* Compute number of primary and leftover components of the JC loop. */ \ + /*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ \ + const dim_t jc_left = n_local % NC; \ +\ + /* Loop over the n dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = jc_start; jj < jc_end; jj += NC ) \ + { \ + /* Calculate the thread's current JC block dimension. */ \ + const dim_t nc_cur = ( NC <= jc_end - jj ? NC : jc_left ); \ +\ + ctype* restrict b_jc = b_00 + jj * jcstep_b; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_pc = bli_thrinfo_sub_node( thread_jc ); \ + bli_thrinfo_sup_grow( rntm, bszids_pc, thread_pc ); \ +\ + /* Compute the PC loop thread range for the current thread. */ \ + const dim_t pc_start = 0, pc_end = k; \ + const dim_t k_local = k; \ +\ + /* Compute number of primary and leftover components of the PC loop. */ \ + /*const dim_t pc_iter = ( k_local + KC - 1 ) / KC;*/ \ + const dim_t pc_left = k_local % KC; \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = pc_start; pp < pc_end; pp += KC ) \ + { \ + /* Calculate the thread's current PC block dimension. */ \ + const dim_t kc_cur = ( KC <= pc_end - pp ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_00 + pp * pcstep_a; \ + ctype* restrict b_pc = b_jc + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? &beta_local : &one_local ); \ +\ + ctype* b_use; \ + inc_t rs_b_use, cs_b_use, ps_b_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pb = bli_thrinfo_sub_node( thread_pc ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pb, thread_pb );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + B. Then call the packm implementation. 
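+	   bls_?packm_b() acquires (or reuses) a buffer from the pool reserved
+	   for panels of B and packs the current kc_cur x nc_cur block into
+	   row-stored micropanels that are NR columns wide; the buffer address
+	   and its strides are returned via b_use, rs_b_use, cs_b_use, and
+	   ps_b_use.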
*/ \ + PASTECH2(bls_,ch,packm_b) \ + ( \ + conjb, \ + KC, NC, \ + kc_cur, nc_cur, NR, \ + &one_local, \ + b_pc, rs_b, cs_b, \ + &b_use, &rs_b_use, &cs_b_use, \ + &ps_b_use, \ + cntx, \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ + /* Alias b_use so that it's clear this is our current block of + matrix B. */ \ + ctype* restrict b_pc_use = b_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_ic = bli_thrinfo_sub_node( thread_pb ); \ + bli_thrinfo_sup_grow( rntm, bszids_ic, thread_ic ); \ +\ + /* Compute the IC loop thread range for the current thread. */ \ + dim_t ic_start, ic_end; \ + bli_thread_range_sub( thread_ic, m, MR, FALSE, &ic_start, &ic_end ); \ + const dim_t m_local = ic_end - ic_start; \ +\ + /* Compute number of primary and leftover components of the IC loop. */ \ + /*const dim_t ic_iter = ( m_local + MC - 1 ) / MC;*/ \ + const dim_t ic_left = m_local % MC; \ +\ + /* Loop over the m dimension (MC rows at a time). */ \ + for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \ + { \ + /* Calculate the thread's current IC block dimension. */ \ + const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \ +\ + ctype* restrict a_ic = a_pc + ii * icstep_a; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + ctype* a_use; \ + inc_t rs_a_use, cs_a_use, ps_a_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pa = bli_thrinfo_sub_node( thread_ic ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pa, thread_pa );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + A. Then call the packm implementation. */ \ + PASTECH2(bls_,ch,packm_a) \ + ( \ + conja, \ + MC, KC, \ + mc_cur, kc_cur, MR, \ + &one_local, \ + a_ic, rs_a, cs_a, \ + &a_use, &rs_a_use, &cs_a_use, \ + &ps_a_use, \ + cntx, \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ +\ + /* Alias a_use so that it's clear this is our current block of + matrix A. */ \ + ctype* restrict a_ic_use = a_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jr = bli_thrinfo_sub_node( thread_pa ); \ + bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \ +\ + /* Query the number of threads and thread ids for the JR loop. + NOTE: These values are only needed when computing the next + micropanel of B. */ \ + const dim_t jr_nt = bli_thread_n_way( thread_jr ); \ + const dim_t jr_tid = bli_thread_work_id( thread_jr ); \ +\ + /* Compute number of primary and leftover components of the JR loop. */ \ + dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ + dim_t jr_left = nc_cur % NR; \ +\ + /* Compute the JR loop thread range for the current thread. */ \ + dim_t jr_start, jr_end; \ + bli_thread_range_sub( thread_jr, jr_iter, 1, FALSE, &jr_start, &jr_end ); \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( dim_t j = jr_start; j < jr_end; j += 1 ) \ + { \ + const dim_t nr_cur \ + = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); \ +\ + ctype* restrict b_jr = b_pc_use + j * ps_b_use; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* Assume for now that our next panel of B to be the current panel + of B. */ \ + ctype* restrict b2 = b_jr; \ +\ + /* Identify the current thrinfo_t node. 
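+	   The IR loop below pairs each MR-row micropanel of the packed block
+	   of A with the current NR-column micropanel of B and invokes the
+	   microkernel wrapper on the corresponding MR x NR tile of C.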
*/ \ + thread_ir = bli_thrinfo_sub_node( thread_jr ); \ +\ + /* Query the number of threads and thread ids for the IR loop. + NOTE: These values are only needed when computing the next + micropanel of A. */ \ + const dim_t ir_nt = bli_thread_n_way( thread_ir ); \ + const dim_t ir_tid = bli_thread_work_id( thread_ir ); \ +\ + /* Compute number of primary and leftover components of the IR loop. */ \ + dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ + dim_t ir_left = mc_cur % MR; \ +\ + /* Compute the IR loop thread range for the current thread. */ \ + dim_t ir_start, ir_end; \ + bli_thread_range_sub( thread_ir, ir_iter, 1, FALSE, &ir_start, &ir_end ); \ +\ + /* Loop over the m dimension (MR rows at a time). */ \ + for ( dim_t i = ir_start; i < ir_end; i += 1 ) \ + { \ + const dim_t mr_cur \ + = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \ +\ + ctype* restrict a_ir = a_ic_use + i * ps_a_use; \ + ctype* restrict c_ir = c_jr + i * irstep_c; \ +\ + ctype* restrict a2; \ +\ + /* Compute the addresses of the next micropanels of A and B. */ \ + a2 = bli_gemm_get_next_a_upanel( a_ir, ps_a_use, 1 ); \ + if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ + { \ + a2 = a_ic_use; \ + b2 = bli_gemm_get_next_b_upanel( b_jr, ps_b_use, 1 ); \ + if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ + b2 = b_pc_use; \ + } \ +\ + /* Save the addresses of next micropanels of A and B to the + auxinfo_t object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* Call a wrapper to the kernel (which handles edge cases). */ \ + PASTECH2(bls_,ch,gemm_kernel) \ + ( \ + MR, \ + NR, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + &alpha_local, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + beta_use, \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ +\ + /* This barrier is needed to prevent threads from starting to pack + the next row panel of B before the current row panel is fully + computed upon. */ \ + bli_thread_barrier( thread_pb ); \ + } \ + } \ +\ + /* Release any memory that was acquired for packing matrices A and B. */ \ + PASTECH2(bls_,ch,packm_finalize_mem_a) \ + ( \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ + PASTECH2(bls_,ch,packm_finalize_mem_b) \ + ( \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemm_bp_var2: a1_packed", mr_cur, kc_cur, a_ir, rs_a_use, cs_a_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemm_bp_var2: b1_packed", kc_cur, nr_cur, b_jr, rs_b_use, cs_b_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemm_bp_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%5.2f", "" ); \ +*/ \ +} + +//INSERT_GENTFUNC_BASIC0( gemm_bp_var2 ) +GENTFUNC( float, s, gemm_bp_var2 ) +GENTFUNC( double, d, gemm_bp_var2 ) +GENTFUNC( scomplex, c, gemm_bp_var2 ) +GENTFUNC( dcomplex, z, gemm_bp_var2 ) + +// +// -- gemm-like microkernel wrapper -------------------------------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + const dim_t MR, \ + const dim_t NR, \ + dim_t mr_cur, \ + dim_t nr_cur, \ + dim_t kc_cur, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict aux, \ + cntx_t* restrict cntx \ + ) \ +{ \ + /* Infer the datatype from the ctype. 
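+	   (PASTEMAC(ch,type) expands to the num_t constant for the type
+	   character ch -- e.g., BLIS_DOUBLE when ch is 'd'.)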
*/ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for the microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +\ + ctype zero = *PASTEMAC(ch,0); \ +\ + /* Clear the temporary C buffer in case it has any infs or NaNs. + NOTE: This initialization should really be done statically since + var2 executes this microkernel wrapper many times, and the overhead + of touching the temporary microtile adds up. */ \ + PASTEMAC(ch,set0s_mxn)( MR, NR, ct, rs_ct, cs_ct ); \ +\ + /* Handle interior and edge cases separately. */ \ + if ( mr_cur == MR && nr_cur == NR ) \ + { \ + /* Invoke the gemm microkernel. */ \ + gemm_ukr \ + ( \ + kc_cur, \ + alpha, \ + a, \ + b, \ + beta, \ + c, rs_c, cs_c, \ + aux, \ + cntx \ + ); \ + } \ + else \ + { \ + /* Invoke the gemm microkernel. */ \ + gemm_ukr \ + ( \ + kc_cur, \ + alpha, \ + a, \ + b, \ + &zero, \ + ct, rs_ct, cs_ct, \ + aux, \ + cntx \ + ); \ +\ + /* Scale the bottom edge of C and add the result from above. */ \ + PASTEMAC(ch,xpbys_mxn) \ + ( \ + mr_cur, \ + nr_cur, \ + ct, rs_ct, cs_ct, \ + beta, \ + c, rs_c, cs_c \ + ); \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( gemm_kernel ) +GENTFUNC( float, s, gemm_kernel ) +GENTFUNC( double, d, gemm_kernel ) +GENTFUNC( scomplex, c, gemm_kernel ) +GENTFUNC( dcomplex, z, gemm_kernel ) + diff --git a/sandbox/gemmlike/bls_gemm_var.h b/sandbox/gemmlike/bls_gemm_var.h new file mode 100644 index 0000000000..025b54a06f --- /dev/null +++ b/sandbox/gemmlike/bls_gemm_var.h @@ -0,0 +1,124 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +// +// Prototype the object-based variant interfaces. +// + +#undef GENPROT +#define GENPROT( opname ) \ +\ +void PASTECH(bls_,opname) \ + ( \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm, \ + thrinfo_t* thread \ + ); + +GENPROT( gemm_bp_var1 ) +GENPROT( gemm_bp_var2 ) + + +// +// Prototype the typed variant interfaces. +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ); + +//INSERT_GENTPROT_BASIC0( gemm_bp_var1 ) +GENTPROT( float, s, gemm_bp_var1 ) +GENTPROT( double, d, gemm_bp_var1 ) +GENTPROT( scomplex, c, gemm_bp_var1 ) +GENTPROT( dcomplex, z, gemm_bp_var1 ) + +//INSERT_GENTPROT_BASIC0( gemm_bp_var2 ) +GENTPROT( float, s, gemm_bp_var2 ) +GENTPROT( double, d, gemm_bp_var2 ) +GENTPROT( scomplex, c, gemm_bp_var2 ) +GENTPROT( dcomplex, z, gemm_bp_var2 ) + + +// +// Prototype the typed kernel interfaces. +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + const dim_t MR, \ + const dim_t NR, \ + dim_t mr_cur, \ + dim_t nr_cur, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict aux, \ + cntx_t* restrict cntx \ + ); + +//INSERT_GENTPROT_BASIC0( gemm_kernel ) +GENTPROT( float, s, gemm_kernel ) +GENTPROT( double, d, gemm_kernel ) +GENTPROT( scomplex, c, gemm_kernel ) +GENTPROT( dcomplex, z, gemm_kernel ) + diff --git a/sandbox/gemmlike/bls_l3_packm_a.c b/sandbox/gemmlike/bls_l3_packm_a.c new file mode 100644 index 0000000000..c55a19c7b7 --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_a.c @@ -0,0 +1,328 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + /* Set the pack buffer type so that we are obtaining memory blocks from + the pool dedicated to blocks of A. */ \ + const packbuf_t pack_buf_type = BLIS_BUFFER_FOR_A_BLOCK; \ +\ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + const dim_t m_pack = ( m / mr + ( m % mr ? 1 : 0 ) ) * mr; \ + const dim_t k_pack = k; \ +\ + /* Barrier to make sure all threads are caught up and ready to begin the + packm stage. */ \ + bli_thread_barrier( thread ); \ +\ + /* Compute the size of the memory block eneded. */ \ + siz_t size_needed = sizeof( ctype ) * m_pack * k_pack; \ +\ + /* Check the mem_t entry provided by the caller. If it is unallocated, + then we need to acquire a block from the memory broker. */ \ + if ( bli_mem_is_unalloc( mem ) ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Acquire directly to the chief thread's mem_t that was passed in. + It needs to be that mem_t struct, and not a local (temporary) + mem_t, since there is no barrier until after packing is finished, + which could allow a race condition whereby the chief thread exits + the current function before the other threads have a chance to + copy from it. (A barrier would fix that race condition, but then + again, I prefer to keep barriers to a minimum.) */ \ + bli_membrk_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t to all + threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else /* if ( bli_mem_is_alloc( mem ) ) */ \ + { \ + /* If the mem_t entry provided by the caller does NOT contain a NULL + buffer, then a block has already been acquired from the memory + broker and cached by the caller. 
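+	   All that remains is to verify (below) that the cached block is still
+	   large enough before reusing it.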
*/ \ +\ + /* As a sanity check, we should make sure that the mem_t object isn't + associated with a block that is too small compared to the size of + the packed matrix buffer that is needed, according to the value + computed above. */ \ + siz_t mem_size = bli_mem_size( mem ); \ +\ + if ( mem_size < size_needed ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* The chief thread releases the existing block associated + with the mem_t, and then re-acquires a new block, saving + the associated mem_t to its passed-in mem_t. (See coment + above for why the acquisition needs to be directly to + the chief thread's passed-in mem_t and not a local + (temporary) mem_t. */ \ + bli_membrk_release \ + ( \ + rntm, \ + mem \ + ); \ + bli_membrk_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t + to all threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else \ + { \ + /* If the mem_t entry is already allocated and sufficiently large, + then we use it as-is. No action is needed. */ \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_mem_a ) +GENTFUNC( float, s, packm_init_mem_a ) +GENTFUNC( double, d, packm_init_mem_a ) +GENTFUNC( scomplex, c, packm_init_mem_a ) +GENTFUNC( dcomplex, z, packm_init_mem_a ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + if ( thread != NULL ) \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Check the mem_t entry provided by the caller. Only proceed if it + is allocated, which it should be. */ \ + if ( bli_mem_is_alloc( mem ) ) \ + { \ + bli_membrk_release \ + ( \ + rntm, \ + mem \ + ); \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_finalize_mem_a ) +GENTFUNC( float, s, packm_finalize_mem_a ) +GENTFUNC( double, d, packm_finalize_mem_a ) +GENTFUNC( scomplex, c, packm_finalize_mem_a ) +GENTFUNC( dcomplex, z, packm_finalize_mem_a ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + dim_t* restrict m_max, \ + dim_t* restrict k_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ) \ +{ \ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + *m_max = ( m / mr + ( m % mr ? 1 : 0 ) ) * mr; \ + *k_max = k; \ +\ + /* Determine the dimensions and strides for the packed matrix A. */ \ + { \ + /* Pack A to column-stored row-panels. */ \ + *rs_p = 1; \ + *cs_p = mr; \ +\ + *pd_p = mr; \ + *ps_p = mr * k; \ +\ + /* Set the schema to "packed row panels" to indicate packing to + conventional column-stored row panels. 
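+	   For example, with mr = 8 and k = 256, each packed micropanel of A
+	   occupies mr * k = 2048 elements, which is exactly the panel stride
+	   ps_p set above.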
*/ \ + *schema = BLIS_PACKED_ROW_PANELS; \ + } \ +\ + /* Set the buffer address provided by the caller to point to the memory + associated with the mem_t entry acquired from the memory pool. */ \ + *p = bli_mem_buffer( mem ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_a ) +GENTFUNC( float, s, packm_init_a ) +GENTFUNC( double, d, packm_init_a ) +GENTFUNC( scomplex, c, packm_init_a ) +GENTFUNC( dcomplex, z, packm_init_a ) + + +// +// Define BLAS-like interfaces to the variant chooser. +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t m_alloc, \ + dim_t k_alloc, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + pack_t schema; \ + dim_t m_max; \ + dim_t k_max; \ + dim_t pd_p; \ +\ + /* Prepare the packing destination buffer. */ \ + PASTECH2(bls_,ch,packm_init_mem_a) \ + ( \ + m_alloc, k_alloc, mr, \ + cntx, \ + rntm, \ + mem, \ + thread \ + ); \ +\ + /* Determine the packing buffer and related parameters for matrix A. */ \ + PASTECH2(bls_,ch,packm_init_a) \ + ( \ + &schema, \ + m, k, mr, \ + &m_max, &k_max, \ + p, rs_p, cs_p, \ + &pd_p, ps_p, \ + mem \ + ); \ +\ + /* Pack matrix A to the destination buffer chosen above. Here, the packed + matrix is stored to column-stored MR x k micropanels. */ \ + PASTECH2(bls_,ch,packm_var1) \ + ( \ + conj, \ + schema, \ + m, \ + k, \ + m_max, \ + k_max, \ + kappa, \ + a, rs_a, cs_a, \ + *p, *rs_p, *cs_p, \ + pd_p, *ps_p, \ + cntx, \ + thread \ + ); \ +\ + /* Barrier so that packing is done before computation. */ \ + bli_thread_barrier( thread ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_a ) +GENTFUNC( float, s, packm_a ) +GENTFUNC( double, d, packm_a ) +GENTFUNC( scomplex, c, packm_a ) +GENTFUNC( dcomplex, z, packm_a ) + diff --git a/sandbox/gemmlike/bls_l3_packm_a.h b/sandbox/gemmlike/bls_l3_packm_a.h new file mode 100644 index 0000000000..201a24efae --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_a.h @@ -0,0 +1,122 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_mem_a ) +GENTPROT( float, s, packm_init_mem_a ) +GENTPROT( double, d, packm_init_mem_a ) +GENTPROT( scomplex, c, packm_init_mem_a ) +GENTPROT( dcomplex, z, packm_init_mem_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_finalize_mem_a ) +GENTPROT( float, s, packm_finalize_mem_a ) +GENTPROT( double, d, packm_finalize_mem_a ) +GENTPROT( scomplex, c, packm_finalize_mem_a ) +GENTPROT( dcomplex, z, packm_finalize_mem_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + dim_t* restrict m_max, \ + dim_t* restrict k_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_a ) +GENTPROT( float, s, packm_init_a ) +GENTPROT( double, d, packm_init_a ) +GENTPROT( scomplex, c, packm_init_a ) +GENTPROT( dcomplex, z, packm_init_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t m_alloc, \ + dim_t k_alloc, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_a ) +GENTPROT( float, s, packm_a ) +GENTPROT( double, d, packm_a ) +GENTPROT( scomplex, c, packm_a ) +GENTPROT( dcomplex, z, packm_a ) + diff --git a/sandbox/gemmlike/bls_l3_packm_b.c b/sandbox/gemmlike/bls_l3_packm_b.c new file mode 100644 index 0000000000..cae93df012 --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_b.c @@ -0,0 +1,328 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + /* Set the pack buffer type so that we are obtaining memory blocks from + the pool dedicated to panels of B. */ \ + const packbuf_t pack_buf_type = BLIS_BUFFER_FOR_B_PANEL; \ +\ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + const dim_t k_pack = k; \ + const dim_t n_pack = ( n / nr + ( n % nr ? 1 : 0 ) ) * nr; \ +\ + /* Barrier to make sure all threads are caught up and ready to begin the + packm stage. */ \ + bli_thread_barrier( thread ); \ +\ + /* Compute the size of the memory block eneded. */ \ + siz_t size_needed = sizeof( ctype ) * k_pack * n_pack; \ +\ + /* Check the mem_t entry provided by the caller. If it is unallocated, + then we need to acquire a block from the memory broker. */ \ + if ( bli_mem_is_unalloc( mem ) ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Acquire directly to the chief thread's mem_t that was passed in. + It needs to be that mem_t struct, and not a local (temporary) + mem_t, since there is no barrier until after packing is finished, + which could allow a race condition whereby the chief thread exits + the current function before the other threads have a chance to + copy from it. (A barrier would fix that race condition, but then + again, I prefer to keep barriers to a minimum.) */ \ + bli_membrk_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t to all + threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else /* if ( bli_mem_is_alloc( mem ) ) */ \ + { \ + /* If the mem_t entry provided by the caller does NOT contain a NULL + buffer, then a block has already been acquired from the memory + broker and cached by the caller. 
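+	   In that case we only need to confirm (below) that the cached block
+	   is large enough to be reused.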
*/ \ +\ + /* As a sanity check, we should make sure that the mem_t object isn't + associated with a block that is too small compared to the size of + the packed matrix buffer that is needed, according to the value + computed above. */ \ + siz_t mem_size = bli_mem_size( mem ); \ +\ + if ( mem_size < size_needed ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* The chief thread releases the existing block associated + with the mem_t, and then re-acquires a new block, saving + the associated mem_t to its passed-in mem_t. (See coment + above for why the acquisition needs to be directly to + the chief thread's passed-in mem_t and not a local + (temporary) mem_t. */ \ + bli_membrk_release \ + ( \ + rntm, \ + mem \ + ); \ + bli_membrk_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t + to all threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else \ + { \ + /* If the mem_t entry is already allocated and sufficiently large, + then we use it as-is. No action is needed. */ \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_mem_b ) +GENTFUNC( float, s, packm_init_mem_b ) +GENTFUNC( double, d, packm_init_mem_b ) +GENTFUNC( scomplex, c, packm_init_mem_b ) +GENTFUNC( dcomplex, z, packm_init_mem_b ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + if ( thread != NULL ) \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Check the mem_t entry provided by the caller. Only proceed if it + is allocated, which it should be. */ \ + if ( bli_mem_is_alloc( mem ) ) \ + { \ + bli_membrk_release \ + ( \ + rntm, \ + mem \ + ); \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_finalize_mem_b ) +GENTFUNC( float, s, packm_finalize_mem_b ) +GENTFUNC( double, d, packm_finalize_mem_b ) +GENTFUNC( scomplex, c, packm_finalize_mem_b ) +GENTFUNC( dcomplex, z, packm_finalize_mem_b ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + dim_t* restrict k_max, \ + dim_t* restrict n_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ) \ +{ \ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + *k_max = k; \ + *n_max = ( n / nr + ( n % nr ? 1 : 0 ) ) * nr; \ +\ + /* Determine the dimensions and strides for the packed matrix B. */ \ + { \ + /* Pack B to row-stored column-panels. */ \ + *rs_p = nr; \ + *cs_p = 1; \ +\ + *pd_p = nr; \ + *ps_p = k * nr; \ +\ + /* Set the schema to "packed column panels" to indicate packing to + conventional row-stored column panels. 
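+	   For example, with nr = 8 and k = 256, each packed micropanel of B
+	   occupies k * nr = 2048 elements, matching the panel stride ps_p set
+	   above.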
*/ \ + *schema = BLIS_PACKED_COL_PANELS; \ + } \ +\ + /* Set the buffer address provided by the caller to point to the memory + associated with the mem_t entry acquired from the memory pool. */ \ + *p = bli_mem_buffer( mem ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_b ) +GENTFUNC( float, s, packm_init_b ) +GENTFUNC( double, d, packm_init_b ) +GENTFUNC( scomplex, c, packm_init_b ) +GENTFUNC( dcomplex, z, packm_init_b ) + + +// +// Define BLAS-like interfaces to the variant chooser. +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t k_alloc, \ + dim_t n_alloc, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + ctype* restrict kappa, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + pack_t schema; \ + dim_t k_max; \ + dim_t n_max; \ + dim_t pd_p; \ +\ + /* Prepare the packing destination buffer. */ \ + PASTECH2(bls_,ch,packm_init_mem_b) \ + ( \ + k_alloc, n_alloc, nr, \ + cntx, \ + rntm, \ + mem, \ + thread \ + ); \ +\ + /* Determine the packing buffer and related parameters for matrix B. */ \ + PASTECH2(bls_,ch,packm_init_b) \ + ( \ + &schema, \ + k, n, nr, \ + &k_max, &n_max, \ + p, rs_p, cs_p, \ + &pd_p, ps_p, \ + mem \ + ); \ +\ + /* Pack matrix B to the destination buffer chosen above. Here, the packed + matrix is stored to row-stored k x NR micropanels. */ \ + PASTECH2(bls_,ch,packm_var1) \ + ( \ + conj, \ + schema, \ + k, \ + n, \ + k_max, \ + n_max, \ + kappa, \ + b, rs_b, cs_b, \ + *p, *rs_p, *cs_p, \ + pd_p, *ps_p, \ + cntx, \ + thread \ + ); \ +\ + /* Barrier so that packing is done before computation. */ \ + bli_thread_barrier( thread ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_b ) +GENTFUNC( float, s, packm_b ) +GENTFUNC( double, d, packm_b ) +GENTFUNC( scomplex, c, packm_b ) +GENTFUNC( dcomplex, z, packm_b ) + diff --git a/sandbox/gemmlike/bls_l3_packm_b.h b/sandbox/gemmlike/bls_l3_packm_b.h new file mode 100644 index 0000000000..728d21aed5 --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_b.h @@ -0,0 +1,122 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_mem_b ) +GENTPROT( float, s, packm_init_mem_b ) +GENTPROT( double, d, packm_init_mem_b ) +GENTPROT( scomplex, c, packm_init_mem_b ) +GENTPROT( dcomplex, z, packm_init_mem_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_finalize_mem_b ) +GENTPROT( float, s, packm_finalize_mem_b ) +GENTPROT( double, d, packm_finalize_mem_b ) +GENTPROT( scomplex, c, packm_finalize_mem_b ) +GENTPROT( dcomplex, z, packm_finalize_mem_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + dim_t* restrict k_max, \ + dim_t* restrict n_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_b ) +GENTPROT( float, s, packm_init_b ) +GENTPROT( double, d, packm_init_b ) +GENTPROT( scomplex, c, packm_init_b ) +GENTPROT( dcomplex, z, packm_init_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t k_alloc, \ + dim_t n_alloc, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + ctype* restrict kappa, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_b ) +GENTPROT( float, s, packm_b ) +GENTPROT( double, d, packm_b ) +GENTPROT( scomplex, c, packm_b ) +GENTPROT( dcomplex, z, packm_b ) + diff --git a/sandbox/gemmlike/bls_l3_packm_var.c b/sandbox/gemmlike/bls_l3_packm_var.c new file mode 100644 index 0000000000..3265ef834d --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_var.c @@ -0,0 +1,198 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// Define BLAS-like interfaces to the variants. +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + ctype* restrict kappa_cast = kappa; \ + ctype* restrict c_cast = c; \ + ctype* restrict p_cast = p; \ +\ + dim_t iter_dim; \ + dim_t n_iter; \ + dim_t it, ic; \ + dim_t ic0; \ + doff_t ic_inc; \ + dim_t panel_len_full; \ + dim_t panel_len_i; \ + dim_t panel_len_max; \ + dim_t panel_len_max_i; \ + dim_t panel_dim_i; \ + dim_t panel_dim_max; \ + inc_t vs_c; \ + inc_t ldc; \ + inc_t ldp; \ + conj_t conjc; \ +\ +\ + /* Extract the conjugation bit from the transposition argument. */ \ + conjc = bli_extract_conj( transc ); \ +\ + /* Create flags to incidate row or column storage. Note that the + schema bit that encodes row or column is describing the form of + micro-panel, not the storage in the micro-panel. Hence the + mismatch in "row" and "column" semantics. */ \ + bool row_stored = bli_is_col_packed( schema ); \ + /*bool col_stored = bli_is_row_packed( schema );*/ \ +\ + /* If the row storage flag indicates row storage, then we are packing + to column panels; otherwise, if the strides indicate column storage, + we are packing to row panels. */ \ + if ( row_stored ) \ + { \ + /* Prepare to pack to row-stored column panels. */ \ + iter_dim = n; \ + panel_len_full = m; \ + panel_len_max = m_max; \ + panel_dim_max = pd_p; \ + vs_c = cs_c; \ + ldc = rs_c; \ + ldp = rs_p; \ + } \ + else /* if ( col_stored ) */ \ + { \ + /* Prepare to pack to column-stored row panels. */ \ + iter_dim = m; \ + panel_len_full = n; \ + panel_len_max = n_max; \ + panel_dim_max = pd_p; \ + vs_c = rs_c; \ + ldc = cs_c; \ + ldp = cs_p; \ + } \ +\ + /* Compute the total number of iterations we'll need. */ \ + n_iter = iter_dim / panel_dim_max + ( iter_dim % panel_dim_max ? 1 : 0 ); \ +\ + /* Set the initial values and increments for indices related to C and P + based on whether reverse iteration was requested. */ \ + { \ + ic0 = 0; \ + ic_inc = panel_dim_max; \ + } \ +\ + ctype* restrict p_begin = p_cast; \ +\ + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. 
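+	   These values determine which micropanels the current thread packs:
+	   with slab partitioning each thread handles a contiguous range of
+	   iterations, while round-robin partitioning interleaves them.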
*/ \ + const dim_t nt = bli_thread_n_way( thread ); \ + const dim_t tid = bli_thread_work_id( thread ); \ +\ + /* Suppress warnings in case tid isn't used (ie: as in slab partitioning). */ \ + ( void )nt; \ + ( void )tid; \ +\ + dim_t it_start, it_end, it_inc; \ +\ + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + will depend on whether slab or round-robin partitioning was requested + at configure-time. */ \ + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ +\ + /* Iterate over every logical micropanel in the source matrix. */ \ + for ( ic = ic0, it = 0; it < n_iter; \ + ic += ic_inc, it += 1 ) \ + { \ + panel_dim_i = bli_min( panel_dim_max, iter_dim - ic ); \ +\ + ctype* restrict c_begin = c_cast + (ic )*vs_c; \ +\ + ctype* restrict c_use = c_begin; \ + ctype* restrict p_use = p_begin; \ +\ + panel_len_i = panel_len_full; \ + panel_len_max_i = panel_len_max; \ +\ + /* The definition of bli_packm_my_iter() will depend on whether slab + or round-robin partitioning was requested at configure-time. (The + default is slab.) */ \ + if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ + { \ + PASTEMAC(ch,packm_cxk) \ + ( \ + conjc, \ + schema, \ + panel_dim_i, \ + panel_dim_max, \ + panel_len_i, \ + panel_len_max_i, \ + kappa_cast, \ + c_use, vs_c, ldc, \ + p_use, ldp, \ + cntx \ + ); \ + } \ +\ +/* +if ( !row_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: a packed", panel_dim_max, panel_len_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +else \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: b packed", panel_len_max, panel_dim_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +*/ \ +\ + p_begin += ps_p; \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_var1 ) +GENTFUNC( float, s, packm_var1 ) +GENTFUNC( double, d, packm_var1 ) +GENTFUNC( scomplex, c, packm_var1 ) +GENTFUNC( dcomplex, z, packm_var1 ) + diff --git a/sandbox/gemmlike/bls_l3_packm_var.h b/sandbox/gemmlike/bls_l3_packm_var.h new file mode 100644 index 0000000000..0e8eb9ee8a --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_var.h @@ -0,0 +1,63 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// +// Prototype BLAS-like interfaces to the variants. +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ); + +//INSERT_GENTPROT_BASIC0( packm_var1 ) +GENTPROT( float, s, packm_var1 ) +GENTPROT( double, d, packm_var1 ) +GENTPROT( scomplex, c, packm_var1 ) +GENTPROT( dcomplex, z, packm_var1 ) + diff --git a/sandbox/gemmlike/thread/bls_l3_decor.h b/sandbox/gemmlike/thread/bls_l3_decor.h new file mode 100644 index 0000000000..bb8a95bb46 --- /dev/null +++ b/sandbox/gemmlike/thread/bls_l3_decor.h @@ -0,0 +1,73 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_H +#define BLIS_SBX_L3_DECOR_H + +// -- sup definitions ---------------------------------------------------------- + +// Level-3 sup internal function type. +typedef void (*l3sbxint_t) + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +// Level-3 sup thread decorator prototype. 
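+// (The decorator spawns or adopts the number of threads requested in the
+// rntm_t and invokes 'func' from each of them with a per-thread rntm_t and
+// thrinfo_t.)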
+void bls_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ); + +// Include definitions specific to the method of multithreading. +#include "bls_l3_decor_single.h" +#include "bls_l3_decor_openmp.h" +#include "bls_l3_decor_pthreads.h" + +#endif + diff --git a/sandbox/gemmlike/thread/bls_l3_decor_openmp.c b/sandbox/gemmlike/thread/bls_l3_decor_openmp.c new file mode 100644 index 0000000000..851a29e52b --- /dev/null +++ b/sandbox/gemmlike/thread/bls_l3_decor_openmp.c @@ -0,0 +1,138 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifdef BLIS_ENABLE_OPENMP + +// Define a dummy thread entry function, which is needed in the pthreads +// version, so that when building Windows DLLs (with OpenMP enabled or with +// no multithreading) we don't risk having an unresolved symbol. +void* bls_l3_thread_entry( void* data_void ) { return NULL; } + +//#define PRINT_THRINFO + +void bls_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // Query the total number of threads from the rntm_t object. + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. We do + // this up-front only so that we have the rntm_t.sba_pool field + // initialized and ready for the global communicator creation below. 
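+	// (Each worker thread later sets the sba_pool field of its own local
+	// copy of the rntm via bli_sba_rntm_set_pool( tid, array, rntm_p )
+	// inside the parallel region.)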
+ bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. This will be + // inherited by all of the child threads when they make local copies of + // the rntm below. + bli_membrk_rntm_set_membrk( rntm ); + + // Allcoate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); + + + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Create a thread-local copy of the master thread's rntm_t. This is + // necessary since we want each thread to be able to track its own + // small block pool_t as it executes down the function stack. + rntm_t rntm_l = *rntm; + rntm_t* restrict rntm_p = &rntm_l; + + // Query the thread's id from OpenMP. + const dim_t tid = omp_get_thread_num(); + + // Check for a somewhat obscure OpenMP thread-mistmatch issue. + // NOTE: This calls the same function used for the conventional/large + // code path. + bli_l3_thread_decorator_thread_check( n_threads, tid, gl_comm, rntm_p ); + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + bli_sba_rntm_set_pool( tid, array, rntm_p ); + + thrinfo_t* thread = NULL; + + // Create the root node of the thread's thrinfo_t structure. + bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread ); + + func + ( + alpha, + a, + b, + beta, + c, + cntx, + rntm_p, + thread + ); + + // Free the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_free( rntm_p, thread ); + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called from the thread entry function). + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); +} + +#endif + diff --git a/sandbox/gemmlike/thread/bls_l3_decor_openmp.h b/sandbox/gemmlike/thread/bls_l3_decor_openmp.h new file mode 100644 index 0000000000..9c956d7c36 --- /dev/null +++ b/sandbox/gemmlike/thread/bls_l3_decor_openmp.h @@ -0,0 +1,44 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_OPENMP_H +#define BLIS_SBX_L3_DECOR_OPENMP_H + +// Definitions specific to situations when OpenMP multithreading is enabled. +#ifdef BLIS_ENABLE_OPENMP + +#endif + +#endif + diff --git a/sandbox/gemmlike/thread/bls_l3_decor_pthreads.c b/sandbox/gemmlike/thread/bls_l3_decor_pthreads.c new file mode 100644 index 0000000000..f87d79fd6c --- /dev/null +++ b/sandbox/gemmlike/thread/bls_l3_decor_pthreads.c @@ -0,0 +1,213 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifdef BLIS_ENABLE_PTHREADS + +// A data structure to assist in passing operands to additional threads. +typedef struct thread_data +{ + l3sbxint_t func; + opid_t family; + obj_t* alpha; + obj_t* a; + obj_t* b; + obj_t* beta; + obj_t* c; + cntx_t* cntx; + rntm_t* rntm; + dim_t tid; + thrcomm_t* gl_comm; + array_t* array; +} thread_data_t; + +// Entry point function for additional threads. +void* bls_l3_thread_entry( void* data_void ) +{ + thread_data_t* data = data_void; + + l3sbxint_t func = data->func; + opid_t family = data->family; + obj_t* alpha = data->alpha; + obj_t* a = data->a; + obj_t* b = data->b; + obj_t* beta = data->beta; + obj_t* c = data->c; + cntx_t* cntx = data->cntx; + rntm_t* rntm = data->rntm; + dim_t tid = data->tid; + array_t* array = data->array; + thrcomm_t* gl_comm = data->gl_comm; + + ( void )family; + + // Create a thread-local copy of the master thread's rntm_t. 
This is + // necessary since we want each thread to be able to track its own + // small block pool_t as it executes down the function stack. + rntm_t rntm_l = *rntm; + rntm_t* restrict rntm_p = &rntm_l; + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + bli_sba_rntm_set_pool( tid, array, rntm_p ); + + thrinfo_t* thread = NULL; + + // Create the root node of the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread ); + + func + ( + alpha, + a, + b, + beta, + c, + cntx, + rntm_p, + thread + ); + + // Free the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_free( rntm_p, thread ); + + return NULL; +} + +void bls_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // Query the total number of threads from the context. + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. We do + // this up-front only so that we have the rntm_t.sba_pool field + // initialized and ready for the global communicator creation below. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. This will be + // inherited by all of the child threads when they make local copies of + // the rntm below. + bli_membrk_rntm_set_membrk( rntm ); + + // Allocate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); + + // Allocate an array of pthread objects and auxiliary data structs to pass + // to the thread entry functions. + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_pthread_t* pthreads = bli_malloc_intl( sizeof( bli_pthread_t ) * n_threads ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + thread_data_t* datas = bli_malloc_intl( sizeof( thread_data_t ) * n_threads ); + + // NOTE: We must iterate backwards so that the chief thread (thread id 0) + // can spawn all other threads before proceeding with its own computation. + for ( dim_t tid = n_threads - 1; 0 <= tid; tid-- ) + { + // Set up thread data for additional threads (beyond thread 0). + datas[tid].func = func; + datas[tid].family = family; + datas[tid].alpha = alpha; + datas[tid].a = a; + datas[tid].b = b; + datas[tid].beta = beta; + datas[tid].c = c; + datas[tid].cntx = cntx; + datas[tid].rntm = rntm; + datas[tid].tid = tid; + datas[tid].gl_comm = gl_comm; + datas[tid].array = array; + + // Spawn additional threads for ids greater than 1. 
+ if ( tid != 0 ) + bli_pthread_create( &pthreads[tid], NULL, &bls_l3_thread_entry, &datas[tid] ); + else + bls_l3_thread_entry( ( void* )(&datas[0]) ); + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called from the thread entry function). + + // Thread 0 waits for additional threads to finish. + for ( dim_t tid = 1; tid < n_threads; tid++ ) + { + bli_pthread_join( pthreads[tid], NULL ); + } + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_free_intl( pthreads ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_free_intl( datas ); +} + +#endif + diff --git a/sandbox/gemmlike/thread/bls_l3_decor_pthreads.h b/sandbox/gemmlike/thread/bls_l3_decor_pthreads.h new file mode 100644 index 0000000000..ef5c3bad45 --- /dev/null +++ b/sandbox/gemmlike/thread/bls_l3_decor_pthreads.h @@ -0,0 +1,47 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_PTHREADS_H +#define BLIS_SBX_L3_DECOR_PTHREADS_H + +// Definitions specific to situations when POSIX multithreading is enabled. +#ifdef BLIS_ENABLE_PTHREADS + +// Thread entry point prototype. +void* bls_l3_thread_entry( void* data_void ); + +#endif + +#endif + diff --git a/sandbox/gemmlike/thread/bls_l3_decor_single.c b/sandbox/gemmlike/thread/bls_l3_decor_single.c new file mode 100644 index 0000000000..7d9017dcd5 --- /dev/null +++ b/sandbox/gemmlike/thread/bls_l3_decor_single.c @@ -0,0 +1,141 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
+ + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifndef BLIS_ENABLE_MULTITHREADING + +#define SKIP_THRINFO_TREE + +void bls_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + //pack_t schema_a, + //pack_t schema_b, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // For sequential execution, we use only one thread. + const dim_t n_threads = 1; + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. + bli_membrk_rntm_set_membrk( rntm ); + +#ifndef SKIP_THRINFO_TREE + // Allcoate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); +#endif + + + { + // NOTE: We don't need to create another copy of the rntm_t since + // it was already copied in one of the high-level oapi functions. + rntm_t* restrict rntm_p = rntm; + + // There is only one thread id (for the thief thread). + const dim_t tid = 0; + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + // NOTE: This is commented out because, in the single-threaded case, + // this is redundant since it's already been done above. + //bli_sba_rntm_set_pool( tid, array, rntm_p ); + +#ifndef SKIP_THRINFO_TREE + thrinfo_t* thread = NULL; + + // Create the root node of the thread's thrinfo_t structure. 
+ bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread ); +#else + // This optimization allows us to use one of the global thrinfo_t + // objects for single-threaded execution rather than grow one from + // scratch. The key is that bli_thrinfo_sup_grow(), which is called + // from within the variants, will immediately return if it detects + // that the thrinfo_t* passed into it is either + // &BLIS_GEMM_SINGLE_THREADED or &BLIS_PACKM_SINGLE_THREADED. + thrinfo_t* thread = &BLIS_GEMM_SINGLE_THREADED; + + ( void )tid; +#endif + + func + ( + alpha, + a, + b, + beta, + c, + cntx, + rntm_p, + thread + ); + +#ifndef SKIP_THRINFO_TREE + // Free the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_free( rntm_p, thread ); +#endif + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called above). + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); +} + +#endif + diff --git a/sandbox/gemmlike/thread/bls_l3_decor_single.h b/sandbox/gemmlike/thread/bls_l3_decor_single.h new file mode 100644 index 0000000000..211a43a894 --- /dev/null +++ b/sandbox/gemmlike/thread/bls_l3_decor_single.h @@ -0,0 +1,44 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_SINGLE_H +#define BLIS_SBX_L3_DECOR_SINGLE_H + +// Definitions specific to situations when multithreading is disabled. +#ifndef BLIS_ENABLE_MULTITHREADING + +#endif + +#endif + diff --git a/sandbox/power10/bli_gemmnat.c b/sandbox/power10/bli_gemmnat.c index b2dabd29aa..846ccd35a8 100644 --- a/sandbox/power10/bli_gemmnat.c +++ b/sandbox/power10/bli_gemmnat.c @@ -32,7 +32,14 @@ */ -// This file is needed for the BLIS build system. 
+// Given the current architecture of BLIS sandboxes, bli_gemmnat() is the +// entry point to any sandbox implementation. + +// NOTE: This function is implemented identically to the function that it +// overrides in frame/ind/oapi/bli_l3_nat_oapi.c. This means that we are +// forgoing the option of customizing the implementations that underlie +// bli_gemm() and bli_?gemm(). Any new code defined in this sandbox +// directory, however, will be included in the BLIS. #include "blis.h" diff --git a/so_version b/so_version index 8efd5969fe..436b8f7fa7 100644 --- a/so_version +++ b/so_version @@ -1,2 +1,2 @@ -3 -2.0 +4 +0.0 diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index d116e942d0..77b746ba94 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -4,169 +4,169 @@ add_definitions(-DBLAS="AOCL") add_executable(TestAminv test_aminv.c) target_link_libraries(TestAminv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestAminv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestAminv "${OMP_LIB}") endif() target_link_libraries(TestAminv optimized "${LIB_NAME}.lib") add_executable(TestAxpyv test_axpyv.c) target_link_libraries(TestAxpyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestAxpyv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestAxpyv "${OMP_LIB}") endif() target_link_libraries(TestAxpyv optimized "${LIB_NAME}.lib") add_executable(TestAxpbyv test_axpbyv.c) target_link_libraries(TestAxpbyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestAxpbyv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestAxpbyv "${OMP_LIB}") endif() target_link_libraries(TestAxpbyv optimized "${LIB_NAME}.lib") add_executable(TestCopyv test_copyv.c) target_link_libraries(TestCopyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestCopyv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestCopyv "${OMP_LIB}") endif() target_link_libraries(TestCopyv optimized "${LIB_NAME}.lib") add_executable(TestCabs1 test_cabs1.c) target_link_libraries(TestCabs1 debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestCabs1 OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestCabs1 "${OMP_LIB}") endif() target_link_libraries(TestCabs1 optimized "${LIB_NAME}.lib") add_executable(TestDotv test_dotv.c) target_link_libraries(TestDotv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestDotv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestDotv "${OMP_LIB}") endif() target_link_libraries(TestDotv optimized "${LIB_NAME}.lib") add_executable(TestGemm test_gemm.c) target_link_libraries(TestGemm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemm OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGemm "${OMP_LIB}") endif() target_link_libraries(TestGemm optimized "${LIB_NAME}.lib") add_executable(TestGemmBatch test_gemm_batch.c) target_link_libraries(TestGemmBatch debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemmBatch OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGemmBatch "${OMP_LIB}") endif() target_link_libraries(TestGemmBatch optimized "${LIB_NAME}.lib") add_executable(TestGemm3m test_gemm3m.c) target_link_libraries(TestGemm3m debug "${LIB_NAME}.lib") 
-if(ENABLE_OPENMP) - target_link_libraries(TestGemm3m OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGemm3m "${OMP_LIB}") endif() target_link_libraries(TestGemm3m optimized "${LIB_NAME}.lib") add_executable(TestGemmt test_gemmt.c) target_link_libraries(TestGemmt debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemmt OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGemmt "${OMP_LIB}") endif() target_link_libraries(TestGemmt optimized "${LIB_NAME}.lib") add_executable(TestGemv test_gemv.c) target_link_libraries(TestGemv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGemv "${OMP_LIB}") endif() target_link_libraries(TestGemv optimized "${LIB_NAME}.lib") add_executable(TestGer test_ger.c) target_link_libraries(TestGer debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGer OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGer "${OMP_LIB}") endif() target_link_libraries(TestGer optimized "${LIB_NAME}.lib") add_executable(TestHemm test_hemm.c) target_link_libraries(TestHemm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHemm OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHemm "${OMP_LIB}") endif() target_link_libraries(TestHemm optimized "${LIB_NAME}.lib") add_executable(TestHemv test_hemv.c) target_link_libraries(TestHemv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHemv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHemv "${OMP_LIB}") endif() target_link_libraries(TestHemv optimized "${LIB_NAME}.lib") add_executable(TestHer test_her.c) target_link_libraries(TestHer debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHer OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHer "${OMP_LIB}") endif() target_link_libraries(TestHer optimized "${LIB_NAME}.lib") add_executable(TestHer2 test_her2.c) target_link_libraries(TestHer2 debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHer2 OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHer2 "${OMP_LIB}") endif() target_link_libraries(TestHer2 optimized "${LIB_NAME}.lib") add_executable(TestHer2k test_her2k.c) target_link_libraries(TestHer2k debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHer2k OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHer2k "${OMP_LIB}") endif() target_link_libraries(TestHer2k optimized "${LIB_NAME}.lib") add_executable(TestHerk test_herk.c) target_link_libraries(TestHerk debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHerk OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHerk "${OMP_LIB}") endif() target_link_libraries(TestHerk optimized "${LIB_NAME}.lib") add_executable(TestScalv test_scalv.c) target_link_libraries(TestScalv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestScalv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestScalv "${OMP_LIB}") endif() target_link_libraries(TestScalv optimized "${LIB_NAME}.lib") add_executable(TestSwapv test_swapv.c) target_link_libraries(TestSwapv debug "${LIB_NAME}.lib") 
-if(ENABLE_OPENMP) - target_link_libraries(TestSwapv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestSwapv "${OMP_LIB}") endif() target_link_libraries(TestSwapv optimized "${LIB_NAME}.lib") add_executable(TestTrmm test_trmm.c) target_link_libraries(TestTrmm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestTrmm OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestTrmm "${OMP_LIB}") endif() target_link_libraries(TestTrmm optimized "${LIB_NAME}.lib") add_executable(TestTrmv test_trmv.c) target_link_libraries(TestTrmv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestTrmv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestTrmv "${OMP_LIB}") endif() target_link_libraries(TestTrmv optimized "${LIB_NAME}.lib") add_executable(TestTrsm test_trsm.c) target_link_libraries(TestTrsm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestTrsm OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestTrsm "${OMP_LIB}") endif() target_link_libraries(TestTrsm optimized "${LIB_NAME}.lib") add_executable(TestTrsv test_trsv.c) target_link_libraries(TestTrsv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestTrsv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestTrsv "${OMP_LIB}") endif() target_link_libraries(TestTrsv optimized "${LIB_NAME}.lib") diff --git a/test/test_gemm.c b/test/test_gemm.c index 81b7e36616..cc50eb04ae 100644 --- a/test/test_gemm.c +++ b/test/test_gemm.c @@ -48,14 +48,14 @@ //#define CBLAS // Uncomment to enable progress printing. -//#define PROGRESS_ENABLED +#define PROGRESS_ENABLED #ifdef PROGRESS_ENABLED -dim_t AOCL_progress(char *api, - dim_t lapi, - dim_t progress, - dim_t current_thread, - dim_t total_threads) +dim_t AOCL_progress( const char* const api, + const dim_t lapi, + const dim_t progress, + const dim_t current_thread, + const dim_t total_threads ) { printf("\n%s, len = %ld, nt = %ld, tid = %ld, Processed %ld Elements", api, lapi, total_threads, current_thread, progress); diff --git a/test/test_trsm.c b/test/test_trsm.c index f6709f5d7f..3c48015967 100644 --- a/test/test_trsm.c +++ b/test/test_trsm.c @@ -54,11 +54,11 @@ //#define PROGRESS_ENABLED #ifdef PROGRESS_ENABLED -dim_t AOCL_progress(char *api, - dim_t lapi, - dim_t progress, - dim_t current_thread, - dim_t total_threads) +dim_t AOCL_progress( const char* const api, + const dim_t lapi, + const dim_t progress, + const dim_t current_thread, + const dim_t total_threads ) { printf("\n%s, len = %ld, nt = %ld, tid = %ld, Processed %ld Elements", api, lapi, total_threads, current_thread, progress); diff --git a/testsuite/CMakeLists.txt b/testsuite/CMakeLists.txt index 85866926dd..b997b8a8d9 100644 --- a/testsuite/CMakeLists.txt +++ b/testsuite/CMakeLists.txt @@ -7,8 +7,8 @@ add_executable(test_libblis "") add_subdirectory(src) target_link_libraries(test_libblis debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(test_libblis OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(test_libblis "${OMP_LIB}") endif() target_link_libraries(test_libblis optimized "${LIB_NAME}.lib") diff --git a/testsuite/src/test_gemm.c b/testsuite/src/test_gemm.c index fc25e74095..0fbf54df36 100644 --- a/testsuite/src/test_gemm.c +++ b/testsuite/src/test_gemm.c @@ -5,7 +5,7 @@ libraries. 
Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,7 +35,6 @@ #include "blis.h" #include "test_libblis.h" -#define TEST_SQP 0// ENABLE to test sqp path. // Static variables. static char* op_str = "gemm"; @@ -243,18 +242,6 @@ void libblis_test_gemm_experiment sc_str[0], m, n, &c_save ); // Set alpha and beta. -#if TEST_SQP - if ( bli_obj_is_real( &c ) ) - { - bli_setsc( 1.0, 0.0, &alpha ); - bli_setsc( 1.0, 0.0, &beta ); - } - else - { - bli_setsc( 1.0, 0.0, &alpha ); - bli_setsc( 1.0, 0.0, &beta ); - } -#else if ( bli_obj_is_real( &c ) ) { bli_setsc( 1.2, 0.0, &alpha ); @@ -265,20 +252,12 @@ void libblis_test_gemm_experiment bli_setsc( 1.2, 0.8, &alpha ); bli_setsc( 0.9, 1.0, &beta ); } -#endif // Randomize A, B, and C, and save C. libblis_test_mobj_randomize( params, TRUE, &a ); libblis_test_mobj_randomize( params, TRUE, &b ); libblis_test_mobj_randomize( params, TRUE, &c ); bli_copym( &c, &c_save ); -//bli_setm( &BLIS_ONE, &a ); -//bli_setsc( 1.0, 0.0, &alpha ); -//bli_setsc( 0.0, 0.0, &beta ); - -//bli_setm( &BLIS_ONE, &a ); -//bli_setsc( 1.0, 0.0, &alpha ); -//bli_setsc( 0.0, 0.0, &beta ); // Apply the parameters. bli_obj_set_conjtrans( transa, &a ); @@ -458,30 +437,7 @@ void libblis_test_gemm_impl switch ( iface ) { case BLIS_TEST_SEQ_FRONT_END: -#if 0 -//bli_printm( "alpha", alpha, "%5.2f", "" ); -//bli_printm( "beta", beta, "%5.2f", "" ); -bli_printm( "a", a, "%5.2f", "" ); -bli_printm( "b", b, "%5.2f", "" ); -bli_printm( "c", c, "%5.2f", "" ); -#endif -//if ( bli_obj_length( b ) == 16 && -// bli_obj_stor3_from_strides( c, a, b ) == BLIS_CRR ) -//bli_printm( "c before", c, "%6.3f", "" ); - -#if TEST_SQP - if(bli_gemm_sqp(alpha,a,b,beta,c,NULL,NULL)!=BLIS_SUCCESS) - { - bli_gemm( alpha, a, b, beta, c ); - } -#else//TEST_SQP - bli_gemm( alpha, a, b, beta, c ); -#endif//TEST_SQP -#if 0 -if ( bli_obj_length( c ) == 12 && - bli_obj_stor3_from_strides( c, a, b ) == BLIS_RRR ) -bli_printm( "c after", c, "%6.3f", "" ); -#endif + bli_gemm( alpha, a, b, beta, c ); break; default: diff --git a/testsuite/src/test_gemmtrsm_ukr.c b/testsuite/src/test_gemmtrsm_ukr.c index b3916db6a1..a0cec45b92 100644 --- a/testsuite/src/test_gemmtrsm_ukr.c +++ b/testsuite/src/test_gemmtrsm_ukr.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -209,13 +209,32 @@ void libblis_test_gemmtrsm_ukr_experiment // Query a context. cntx = bli_gks_query_cntx(); +#if defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_ZEN4) + /* Zen4 TRSM Fixme: + * + * TRSM and GEMM used different values of MR and NR, we need to ensure that + * Values used for packing are as per the MR and NR values expected by the kernels + * For now this issue exists only for zen4 hence override the values here if + * the family is BLIS_TRSM and architecture is zen4 + * + * We need to override the values here as well as the packing and compute + * kernels are invoked directly from here (instead of BLIS/BLAS call.) + * + * We need to revisit this when TRSM AVX-512 kernels are implemented. 
+ */ + if (bli_arch_query_id() == BLIS_ARCH_ZEN4) + { + bli_zen4_override_trsm_blkszs(cntx); + } +#endif + // Use the datatype of the first char in the datatype combination string. bli_param_map_char_to_blis_dt( dc_str[0], &datatype ); // Map the dimension specifier to actual dimensions. k = libblis_test_get_dim_from_prob_size( op->dim_spec[0], p_cur ); - // Fix m and n to MR and NR, respectively. + m = bli_cntx_get_blksz_def_dt( datatype, BLIS_MR, cntx ); n = bli_cntx_get_blksz_def_dt( datatype, BLIS_NR, cntx ); @@ -224,6 +243,7 @@ void libblis_test_gemmtrsm_ukr_experiment ldap = bli_cntx_get_blksz_max_dt( datatype, BLIS_MR, cntx ); ldbp = bli_cntx_get_blksz_max_dt( datatype, BLIS_NR, cntx ); + // Store the register blocksizes so that the driver can retrieve the // values later when printing results. op->dim_aux[0] = m; @@ -433,6 +453,7 @@ bli_printm( "ap", &ap, "%5.2f", "" ); bli_cntl_free( cntl_b, &BLIS_PACKM_SINGLE_THREADED ); #endif + // Free the packed objects. bli_obj_free( &ap ); bli_obj_free( &bp ); @@ -442,6 +463,20 @@ bli_printm( "ap", &ap, "%5.2f", "" ); bli_obj_free( &b ); bli_obj_free( &c11 ); bli_obj_free( &c11_save ); + +#if defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_ZEN4) + /* Zen4 TRSM Fixme: + * + * We have overrding the block sizes at the start of this function + * Since the context is created only once we need to ensure that the + * default block sizes are restored for the subsequent operations. + */ + if (bli_arch_query_id() == BLIS_ARCH_ZEN4) + { + bli_zen4_restore_default_blkszs(cntx); + } +#endif + } diff --git a/testsuite/src/test_libblis.c b/testsuite/src/test_libblis.c index 6bf58831c1..b904094267 100644 --- a/testsuite/src/test_libblis.c +++ b/testsuite/src/test_libblis.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -2386,7 +2386,11 @@ void libblis_test_op_driver // Mark this operation as done. - op->test_done = TRUE; + if ( tdata->id == 0 ) + op->test_done = TRUE; + + // Wait here so that all threads know we are done + bli_pthread_barrier_wait( tdata->barrier ); } diff --git a/version b/version index 944880fa15..fcdb2e109f 100644 --- a/version +++ b/version @@ -1 +1 @@ -3.2.0 +4.0.0
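For reference, the bls_?packm_var1 prototypes added above are generated by the GENTPROT/PASTECH2 macro machinery rather than written out once per datatype. The hand-expanded form of the double-precision instantiation is sketched below (it is not part of the diff); PASTECH2(bls_,d,packm_var1) pastes the tokens into bls_dpackm_var1, and ctype becomes double.

// Hand-expanded form of GENTPROT( double, d, packm_var1 ).
void bls_dpackm_var1
     (
       trans_t             transc,
       pack_t              schema,
       dim_t               m,
       dim_t               n,
       dim_t               m_max,
       dim_t               n_max,
       double*    restrict kappa,
       double*    restrict c, inc_t rs_c, inc_t cs_c,
       double*    restrict p, inc_t rs_p, inc_t cs_p,
                               dim_t pd_p, inc_t ps_p,
       cntx_t*    restrict cntx,
       thrinfo_t* restrict thread
     );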
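The l3sbxint_t typedef and the bls_l3_thread_decorator() prototype in sandbox/gemmlike/thread/bls_l3_decor.h fix the interface between the sandbox front end and the threading layer. The caller below is a minimal sketch, assuming it is compiled inside the gemmlike sandbox so that bls_l3_decor.h is on the include path; variant_stub and gemm_front_sketch are hypothetical names, and a real variant would perform the packing and block-panel computation.

#include "blis.h"
#include "bls_l3_decor.h"

// Hypothetical variant with the l3sbxint_t signature; a real one would
// partition A, B, and C and call the packm and gemm code paths.
static void variant_stub
     (
       obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c,
       cntx_t* cntx, rntm_t* rntm, thrinfo_t* thread
     )
{
    ( void )alpha; ( void )a; ( void )b; ( void )beta; ( void )c;
    ( void )cntx; ( void )rntm; ( void )thread;
}

void gemm_front_sketch
     (
       obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c,
       cntx_t* cntx, rntm_t* rntm
     )
{
    // The decorator picks the OpenMP, pthreads, or single-threaded backend
    // at build time and calls the variant once per thread, giving each
    // thread its own rntm_t copy and thrinfo_t root.
    bls_l3_thread_decorator( variant_stub, BLIS_GEMM,
                             alpha, a, b, beta, c, cntx, rntm );
}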
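bls_l3_decor_openmp.c opens a single parallel region and gives every OpenMP thread its own copy of the caller's rntm_t, its own pool_t, and its own thrinfo_t root before calling the variant. The standalone program below (not BLIS code) shows the core of that pattern: _Pragma() is simply the operator form of #pragma, and the per-thread copy is what lets each thread mutate its state without racing.

#include <stdio.h>
#include <omp.h>

typedef struct { int n_threads; void* sba_pool; } state_t;   // stand-in for rntm_t

int main( void )
{
    const int n_threads = 4;
    state_t   shared    = { n_threads, NULL };

    _Pragma( "omp parallel num_threads(n_threads)" )
    {
        state_t local = shared;                  // like rntm_l = *rntm in the decorator
        int     tid   = omp_get_thread_num();

        // Each thread would now attach its own pool and thrinfo_t root to
        // 'local' and call the level-3 variant with it.
        printf( "tid %d has a private state copy at %p\n", tid, ( void* )&local );
    }
    return 0;
}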
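bls_l3_decor_pthreads.c marshals the operands into a thread_data_t, spawns threads n_threads-1 down to 1, runs thread 0's work inline, and then joins. The standalone program below mirrors that control flow with plain POSIX threads instead of the bli_pthread_* wrappers.

#include <stdio.h>
#include <stdlib.h>
#include <pthread.h>

typedef struct { int tid; int n_threads; } thread_data_t;

static void* thread_entry( void* data_void )
{
    thread_data_t* data = data_void;
    printf( "tid %d of %d running\n", data->tid, data->n_threads );
    return NULL;
}

int main( void )
{
    const int      n_threads = 4;
    pthread_t*     pthreads  = malloc( sizeof( pthread_t )     * n_threads );
    thread_data_t* datas     = malloc( sizeof( thread_data_t ) * n_threads );

    // Iterate backwards so the chief thread (tid 0) creates all other
    // threads before starting its own share of the work.
    for ( int tid = n_threads - 1; 0 <= tid; tid-- )
    {
        datas[ tid ].tid       = tid;
        datas[ tid ].n_threads = n_threads;

        if ( tid != 0 ) pthread_create( &pthreads[ tid ], NULL, thread_entry, &datas[ tid ] );
        else            thread_entry( &datas[ 0 ] );
    }

    // The chief thread waits for the others to finish.
    for ( int tid = 1; tid < n_threads; tid++ )
        pthread_join( pthreads[ tid ], NULL );

    free( pthreads );
    free( datas );
    return 0;
}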
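bls_l3_decor_single.c skips building a thrinfo_t tree and instead passes the global &BLIS_GEMM_SINGLE_THREADED object, relying on bli_thrinfo_sup_grow() returning early when it is handed one of the two sentinels. The fragment below only illustrates that kind of guard, assuming a plain pointer comparison; it is not the real bli_thrinfo_sup_grow() implementation.

#include "blis.h"

// Illustrative guard: bail out when handed one of the single-threaded
// sentinel thrinfo_t objects named in the decorator's comment.
void thrinfo_sup_grow_sketch( thrinfo_t* thread )
{
    if ( thread == &BLIS_GEMM_SINGLE_THREADED ||
         thread == &BLIS_PACKM_SINGLE_THREADED )
        return;   // single-threaded: nothing to grow

    // ... otherwise create and attach the sub-node for the current loop ...
}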
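test_gemm.c and test_trsm.c now define PROGRESS_ENABLED and const-qualify the AOCL_progress callback parameters. The sketch below shows a callback with the updated signature; the registration step is an assumption (AOCL_BLIS_set_progress() is not part of this diff), so it is left as a comment.

#include <stdio.h>
#include "blis.h"

// Progress callback matching the const-qualified signature in the tests.
dim_t my_progress( const char* const api,
                   const dim_t       lapi,
                   const dim_t       progress,
                   const dim_t       current_thread,
                   const dim_t       total_threads )
{
    ( void )lapi;
    printf( "%s: tid %ld of %ld processed %ld elements\n",
            api, ( long )current_thread, ( long )total_threads, ( long )progress );
    return 0;
}

// Assumed registration call (check the AOCL documentation for the exact API):
// AOCL_BLIS_set_progress( my_progress );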
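testsuite/src/test_gemmtrsm_ukr.c now brackets the experiment with a zen4-specific blocksize override, because the TRSM kernels there expect different MR/NR values than the context defaults. A condensed sketch of that bracket, using only the calls that appear in the diff, is shown below.

#include "blis.h"

void gemmtrsm_ukr_experiment_sketch( void )
{
    cntx_t* cntx = bli_gks_query_cntx();

#if defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_ZEN4)
    if ( bli_arch_query_id() == BLIS_ARCH_ZEN4 )
        bli_zen4_override_trsm_blkszs( cntx );     // use TRSM MR/NR for packing
#endif

    // ... query MR/NR from cntx, pack the micro-panels, run the
    //     gemmtrsm micro-kernel, check the result ...

#if defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_ZEN4)
    if ( bli_arch_query_id() == BLIS_ARCH_ZEN4 )
        bli_zen4_restore_default_blkszs( cntx );   // put the defaults back
#endif
}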
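testsuite/src/test_libblis.c changes libblis_test_op_driver() so that only thread 0 sets op->test_done and every test thread then waits at a barrier before moving on. The standalone fragment below shows the same pattern with a plain pthread_barrier_t instead of the bli_pthread_* wrapper used by the testsuite.

#include <stdbool.h>
#include <pthread.h>

typedef struct
{
    bool               test_done;
    pthread_barrier_t* barrier;
} op_state_t;

void mark_done_and_sync( op_state_t* op, int tid )
{
    // Exactly one writer avoids a data race on test_done.
    if ( tid == 0 ) op->test_done = true;

    // Every thread (including tid 0) waits here, so after the barrier all
    // threads agree that the operation has been marked done.
    pthread_barrier_wait( op->barrier );
}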