
gemm falling back to the three-loop implementation when invoked as a function template, even for matrices of the same datatype #31

Open
YibaiMeng opened this issue Feb 2, 2023 · 5 comments
Labels
enhancement (New feature or request), not a bug (Not considered a bug)

Comments

@YibaiMeng

YibaiMeng commented Feb 2, 2023

blaspp has a fallback gemm at line 90 of include/blas/gemm.hh, shown below. It is a naive three-loop implementation of gemm intended for cases where the three matrices have different data types.

template <typename TA, typename TB, typename TC>  
void gemm(...

However, calls of the form blas::gemm<double>, where all matrices have the same data type, are compiled to that fallback function. Moreover, when the matrices are in row-major format, the correct specialized BLAS routine is invoked: in that case the code calls the column-major version (line 103) without specifying the matrix data types, so the correct overload is selected. For column-major matrices, by contrast, the three-loop version is called.
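The dispatch quirk can be reproduced with a small compile-time sketch. The names `gemm_like`, `Layout`, and `Impl` are hypothetical stand-ins, not the BLAS++ API, and the real row-major branch also swaps the transpose flags and dimensions, which this sketch omits. It only assumes standard C++ overload rules: a non-template wins a tie against a template, and an explicit template argument list excludes non-templates from the candidate set.

```cpp
// Hypothetical stand-ins; this is not the BLAS++ interface.
enum class Layout { ColMajor, RowMajor };
enum class Impl { fallback, optimized };

// Non-template overload, standing in for the optimized dgemm wrapper.
constexpr Impl gemm_like(Layout, double const*, double const*, double*)
{
    return Impl::optimized;
}

// Generic template, standing in for the naive three-loop fallback.
template <typename TA, typename TB, typename TC>
constexpr Impl gemm_like(Layout layout, TA const* A, TB const* B, TC* C)
{
    // Row-major input is re-dispatched as column-major by swapping A and B.
    // This inner call has no explicit template arguments, so overload
    // resolution runs normally and prefers the non-template overload.
    return layout == Layout::RowMajor
               ? gemm_like(Layout::ColMajor, B, A, C)
               : Impl::fallback;  // column-major: the naive loops would run here
}

constexpr double const* A = nullptr;
constexpr double const* B = nullptr;
constexpr double*       C = nullptr;

// Implicit call: the non-template overload wins the tie-break.
static_assert(gemm_like(Layout::ColMajor, A, B, C) == Impl::optimized, "");
// Explicit template argument: only the template is a candidate...
static_assert(gemm_like<double>(Layout::ColMajor, A, B, C) == Impl::fallback, "");
// ...unless the row-major branch re-dispatches without template arguments.
static_assert(gemm_like<double>(Layout::RowMajor, A, B, C) == Impl::optimized, "");
```

Everything is checked at compile time, so building with, e.g., `clang++ -std=c++17 -fsyntax-only` is enough to confirm the three claims.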

This behavior is very unexpected: when calling blas::gemm<double>, most users would expect a specialized double-precision kernel to be used. It can lead to perplexing performance, with row-major gemms running at normal speed while column-major gemms are slow. Although blas::gemm<T> is undocumented, and so this is not strictly a bug, I think it would be better either to use the optimized kernels as the overloaded variants do, or to issue a compiler error or warning.

cc @rileyjmurray @burlen

@YibaiMeng YibaiMeng changed the title gemm falling back to the three-loop implementation when invoked as a function template gemm falling back to the three-loop implementation when invoked as a function template, even for matrices of the same datatype Feb 2, 2023
@mgates3
Collaborator

mgates3 commented Feb 2, 2023

Thanks for reporting this. I never really thought about invoking BLAS++ this way, where the template parameter is specified explicitly. Usually I expect template parameters to be determined implicitly from the arguments. E.g., I expect users to call

double *A, *B, *C;
blas::gemm( ..., A, lda, B, ldb, beta, C, ldc, ... );

rather than the more explicit calls:

blas::gemm<double>( ..., A, lda, B, ldb, beta, C, ldc, ... );
or
blas::gemm<double, double, double>( ..., A, lda, B, ldb, beta, C, ldc, ... );

The row-major behavior was a bit surprising, so thanks for pointing that out. It's clear what's happening, though: in the row-major case, gemm calls itself using the first (implicit) convention, so overload resolution selects the optimized routine.

We could add template specializations for cases that map to optimized routines (e.g., sgemm).
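A minimal sketch of what such a specialization could look like, assuming a hypothetical `gemm_like` template and a hypothetical `optimized_dgemm` stand-in (neither is BLAS++ code; the real routines take many more parameters):

```cpp
// Hypothetical stand-ins; this is not the BLAS++ interface.
enum class Impl { fallback, optimized };

// Stand-in for a call into an optimized routine such as dgemm.
constexpr Impl optimized_dgemm(double const*, double const*, double*)
{
    return Impl::optimized;
}

// Primary template: the generic fallback (loop body elided in this sketch).
template <typename TA, typename TB, typename TC>
constexpr Impl gemm_like(TA const*, TB const*, TC*)
{
    return Impl::fallback;
}

// Explicit specialization for TA = TB = TC = double. It replaces the body
// used for that exact combination, so implicit and explicit calls alike
// reach the optimized routine.
template <>
constexpr Impl gemm_like<double, double, double>(double const* A,
                                                 double const* B, double* C)
{
    return optimized_dgemm(A, B, C);
}

constexpr double const* pA = nullptr;
constexpr double const* pB = nullptr;
constexpr double*       pC = nullptr;

static_assert(gemm_like(pA, pB, pC) == Impl::optimized, "");
static_assert(gemm_like<double>(pA, pB, pC) == Impl::optimized, "");
```

Explicit specializations do not take part in overload resolution; they only change which body runs once the primary template has been selected. One specialization would be needed per routine and per standard type.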

The design uses overloaded routines so that C++ can convert scalar data types, e.g.,

std::complex<float> *A, *B, *C;
blas::gemm( ..., -1.0, A, lda, B, ldb, 1.0, C, ldc, ... );

where alpha and beta are doubles that get implicitly converted to std::complex<float>. This works with overloads, but not with templates, because template argument deduction does not consider implicit conversions.
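The difference can be checked at compile time. In this sketch, `axpy_overload` and `axpy_template` are hypothetical functions (not part of BLAS++), and `template_accepts` uses the common `decltype`/SFINAE detection idiom to ask whether a call would compile:

```cpp
#include <complex>
#include <type_traits>
#include <utility>

// Overload style: alpha's parameter type is fixed, so a double argument
// such as -1.0 is implicitly converted to std::complex<float>.
inline std::complex<float> axpy_overload(std::complex<float> alpha,
                                         std::complex<float> x,
                                         std::complex<float> y)
{
    return alpha * x + y;
}

// Template style: T must be deduced identically from every argument, so a
// double alpha with complex<float> data is a deduction conflict.
template <typename T>
T axpy_template(T alpha, T x, T y)
{
    return alpha * x + y;
}

// Detection helper: true only if axpy_template(A, X, X) is well-formed.
template <typename A, typename X, typename = void>
struct template_accepts : std::false_type {};

template <typename A, typename X>
struct template_accepts<A, X,
    decltype(void(axpy_template(std::declval<A>(), std::declval<X>(),
                                std::declval<X>())))> : std::true_type {};

using cf = std::complex<float>;

// The overload accepts a double alpha...
static_assert(std::is_same<decltype(axpy_overload(-1.0, cf{}, cf{})), cf>::value,
              "overload converts double alpha");
// ...while the template only accepts an alpha of exactly the data type.
static_assert(!template_accepts<double, cf>::value, "deduction conflict");
static_assert(template_accepts<cf, cf>::value, "matching types deduce fine");
```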

@mgates3
Collaborator

mgates3 commented Feb 2, 2023

BTW, don't expect a change in BLAS++ any time soon. We have other things that have higher priority. In the meantime, simply omit the template parameters to get the desired behavior.

@mgates3 mgates3 added the enhancement New feature or request label Feb 2, 2023
@burlen

burlen commented Feb 2, 2023

This is not really a bug, but one might want to prevent users from doing this accidentally, since the performance was really bad and it is not obvious what's going on. To that end, one could use SFINAE to disable the template when all 3 types are the same. The compiler then still finds the overloads when called with identical types, finds the template function when the types differ, and reports an error if the template is called explicitly with identical types. It's quite a small change. Here is a minimal complete example where the standard library's type_traits are used to hide the template function from the compiler:

#include <iostream>
#include <type_traits>

// this template handles data of different types. SFINAE hides it from the
// compiler when the types are the same. This is what we want since there
// are overloads for faster implementations in that case                                                                                                                                                                                
template <typename TA, typename TB, typename TC, 
    typename std::enable_if<
             !(std::is_same<TA,TB>{} && std::is_same<TA,TC>{})
             , bool>::type = true>
void gemm(
    TA const *A, 
    TB const *B, 
    TC       *C) 
{
    std::cerr << "You called the slow method! TA != TB != TC" << std::endl;
}

/// overloads invoking the faster implementation when the types are the same
void gemm(
    double const *A, 
    double const *B, 
    double *C) 
{
    std::cerr << "You called the faster implementation" << std::endl;
}
 

int main()
{
    float *Af = nullptr;
    double *Ad = nullptr;
    double *Bd = nullptr;
    double *Cd = nullptr;

    gemm(Af,Bd,Cd); // OK: the argument types differ, so the template is called
    gemm(Ad,Bd,Cd); // OK: the argument types match, so the fast overload is called

    gemm<double>(Ad,Bd,Cd); // WRONG: specifying the template parameter forces
                            // the compiler to choose the template even though
                            // there's an explicit overload for this case. SFINAE
                            // removes the template, so the error is caught
                            // at compile time.
    return 0;
}

This program fails to compile because of the gemm<double>(...) call; clang reports:

$ clang++ -std=c++17 test_is_same.cpp
test_is_same.cpp:49:5: error: no matching function for call to 'gemm'
    gemm<double>(A,B,C);
    ^~~~~~~~~~~~
test_is_same.cpp:10:6: note: candidate template ignored: requirement '!(std::is_same<double, double>{} && std::is_same<double, double>{})' was not satisfied [with TA = double, TB = double, TC = double]
void gemm(
     ^
1 error generated.

The compiler gives a pretty good error message in this case! When the gemm<double>(...) line is commented out, the program compiles, runs, and prints, as expected:

$ clang++ -std=c++17 test_is_same.cpp
$ ./a.out
You called the slow method! TA != TB != TC
You called the faster implementation

@mgates3
Collaborator

mgates3 commented Feb 2, 2023

I could see instances where TA = TB = TC but the type is not one of the 4 standard types (float, double, complex-float, complex-double), e.g., half, double-double, or int. The thing to disable or redirect is really just the 4 standard types.

Providing specializations isn't all that complicated, but there are quite a few routines in BLAS to update.

Incidentally, you can blame Weslley for the Level 3 BLAS templates. Initially we had only Level 1 and 2 templates.

@burlen

burlen commented Feb 2, 2023

The thing to disable or redirect is really just the 4 standard types.

The following might be closer:

#include <complex>
#include <iostream>
#include <type_traits>

template <typename T, typename ...Ts>
using all_same = std::conjunction<std::is_same<T,Ts>...>;

template< typename T >
struct is_blaspp_type
     : std::integral_constant<
         bool,
         std::is_same<float, typename std::remove_cv<T>::type>::value
         || std::is_same<double, typename std::remove_cv<T>::type>::value
         || std::is_same<std::complex<float>, typename std::remove_cv<T>::type>::value
         || std::is_same<std::complex<double>, typename std::remove_cv<T>::type>::value
     > {};

template <typename T, typename ...Ts>
using all_blaspp_type = std::conjunction<is_blaspp_type<T>, is_blaspp_type<Ts>...>;


// this template is the fallback when a faster implementation is not available.
// SFINAE hides it from the compiler when all 3 types are the same and that
// type is one of blaspp's standard types. This is what we want since there
// are overloads for faster implementations in that case
template <typename TA, typename TB, typename TC,
    typename std::enable_if<
            !all_same<TA,TB,TC>{} || !all_blaspp_type<TA,TB,TC>{}
            , bool>::type = true>
void gemm(
    TA const *A,
    TB const *B,
    TC       *C)
{
    std::cerr << "You called the template implementation. all same? "
        << all_same<TA,TB,TC>::value << " all supported? "
        << all_blaspp_type<TA,TB,TC>::value << std::endl;
}
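To see the refined condition at work, the snippet can be completed into a self-contained C++17 sketch. The `Impl` enum and the constexpr stand-in bodies are assumptions added for illustration; the real BLAS++ overloads take many more parameters. Overloads for all four standard types are included, because once SFINAE hides the template for those types, a same-type call with no matching overload would fail to compile.

```cpp
#include <complex>
#include <type_traits>

enum class Impl { fallback, optimized };  // hypothetical result tag

template <typename T, typename... Ts>
using all_same = std::conjunction<std::is_same<T, Ts>...>;

template <typename T>
struct is_blaspp_type
    : std::integral_constant<bool,
          std::is_same<float,                std::remove_cv_t<T>>::value
       || std::is_same<double,               std::remove_cv_t<T>>::value
       || std::is_same<std::complex<float>,  std::remove_cv_t<T>>::value
       || std::is_same<std::complex<double>, std::remove_cv_t<T>>::value> {};

template <typename T, typename... Ts>
using all_blaspp_type = std::conjunction<is_blaspp_type<T>, is_blaspp_type<Ts>...>;

// Fallback: hidden only when TA = TB = TC and that type is a standard type.
template <typename TA, typename TB, typename TC,
          std::enable_if_t<!all_same<TA, TB, TC>::value
                           || !all_blaspp_type<TA, TB, TC>::value, bool> = true>
constexpr Impl gemm(TA const*, TB const*, TC*) { return Impl::fallback; }

// Stand-ins for the optimized overloads, one per standard type.
using cf = std::complex<float>;
using cd = std::complex<double>;
constexpr Impl gemm(float const*, float const*, float*)    { return Impl::optimized; }
constexpr Impl gemm(double const*, double const*, double*) { return Impl::optimized; }
constexpr Impl gemm(cf const*, cf const*, cf*)             { return Impl::optimized; }
constexpr Impl gemm(cd const*, cd const*, cd*)             { return Impl::optimized; }

constexpr double const* dA = nullptr;
constexpr double*       dC = nullptr;
constexpr int const*    iA = nullptr;
constexpr int*          iC = nullptr;
constexpr float const*  fA = nullptr;

static_assert(gemm(dA, dA, dC) == Impl::optimized, "same standard type");
static_assert(gemm(iA, iA, iC) == Impl::fallback, "same non-standard type");
static_assert(gemm(fA, dA, dC) == Impl::fallback, "mixed types");
static_assert(gemm<int>(iA, iA, iC) == Impl::fallback, "explicit int is allowed");
// gemm<double>(dA, dA, dC) would not compile: the template is hidden.
```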

@mgates3 mgates3 added the not a bug Not considered a bug label Feb 6, 2023
burlen added a commit to BallisticLA/RandLAPACK that referenced this issue Feb 8, 2023
blas mixes templates for slower unoptimized fall back implementations
and overloads for faster optimized implementations. specifying the
template type in calls to blas functions tells the compiler you want to
use the slower template implementation.
For more info see: icl-utk-edu/blaspp#31