diff --git a/Cxx11/xgemm-cblas.cc b/Cxx11/xgemm-cblas.cc index a662af405..d9190c535 100644 --- a/Cxx11/xgemm-cblas.cc +++ b/Cxx11/xgemm-cblas.cc @@ -60,79 +60,100 @@ #include "prk_util.h" #if defined(MKL) -#include + #include + #define PRK_INT MKL_INT + #define PRK_F16 MKL_F16 + #define USE_F16 1 + #define PRK_BF16 MKL_BF16 + #define USE_BF16 1 #elif defined(ACCELERATE) -// The location of cblas.h is not in the system include path when -framework Accelerate is provided. -#include + // The location of cblas.h is not in the system include path when -framework Accelerate is provided. + #include #else -#include -#endif - -#ifndef MKL_INT -#define MKL_INT int + // assume OpenBLAS for now + #include + #ifdef OPENBLAS_USE64BITINT + #define PRK_INT long + #else + #define PRK_INT int + #endif + #define PRK_BF16 bfloat16 #endif template void prk_gemm(const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, - const MKL_INT M, const MKL_INT N, const MKL_INT K, + const PRK_INT M, const PRK_INT N, const PRK_INT K, const TC alpha, - const TAB * A, const MKL_INT lda, - const TAB * B, const MKL_INT ldb, + const TAB * A, const PRK_INT lda, + const TAB * B, const PRK_INT ldb, const TC beta, - TC * C, const MKL_INT ldc) + TC * C, const PRK_INT ldc) { std::cerr << "No valid template match for type T" << std::endl; std::abort(); } -#ifdef MKL_F16 +#ifdef PRK_F16 template <> void prk_gemm(const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, - const MKL_INT M, const MKL_INT N, const MKL_INT K, - const MKL_F16 alpha, - const MKL_F16 * A, const MKL_INT lda, - const MKL_F16 * B, const MKL_INT ldb, - const MKL_F16 beta, - MKL_F16 * C, const MKL_INT ldc) + const PRK_INT M, const PRK_INT N, const PRK_INT K, + const PRK_F16 alpha, + const PRK_F16 * A, const PRK_INT lda, + const PRK_F16 * B, const PRK_INT ldb, + const PRK_F16 beta, + PRK_F16 * C, const PRK_INT ldc) { cblas_hgemm(Layout, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); } #endif -#ifdef MKL_BF16 +#ifdef USE_BF16 template <> void prk_gemm(const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, - const MKL_INT M, const MKL_INT N, const MKL_INT K, + const PRK_INT M, const PRK_INT N, const PRK_INT K, const float alpha, - const MKL_BF16 * A, const MKL_INT lda, - const MKL_BF16 * B, const MKL_INT ldb, + const PRK_BF16 * A, const PRK_INT lda, + const PRK_BF16 * B, const PRK_INT ldb, const float beta, - float * C, const MKL_INT ldc) + float * C, const PRK_INT ldc) { - // cblas_gemm_bf16bf16f32(const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE TransA, - // const CBLAS_TRANSPOSE TransB, - // const MKL_INT M, const MKL_INT N, const MKL_INT K, - // const float alpha, const MKL_BF16 *A, const MKL_INT lda, - // const MKL_BF16 *B, const MKL_INT ldb, const float beta, - // float *C, const MKL_INT ldc); +#ifdef MKL + // MKL + // cblas_gemm_bf16bf16f32(const CBLAS_LAYOUT Layout, + // const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, + // const PRK_INT M, const PRK_INT N, const PRK_INT K, + // const float alpha, const PRK_BF16 *A, const PRK_INT lda, + // const PRK_BF16 *B, const PRK_INT ldb, + // const float beta, float *C, const PRK_INT ldc); cblas_gemm_bf16bf16f32(Layout, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); +#else + // OpenBLAS + // cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, + // OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, + // OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, + // OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, + // OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, + // OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc); + cblas_sbgemm(Layout, TransA, TransB, + M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); +#endif } #endif template <> void prk_gemm(const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, - const MKL_INT M, const MKL_INT N, const MKL_INT K, + const PRK_INT M, const PRK_INT N, const PRK_INT K, const float alpha, - const float * A, const MKL_INT lda, - const float * B, const MKL_INT ldb, + const float * A, const PRK_INT lda, + const float * B, const PRK_INT ldb, const float beta, - float * C, const MKL_INT ldc) + float * C, const PRK_INT ldc) { cblas_sgemm(Layout, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); @@ -141,26 +162,26 @@ void prk_gemm(const CBLAS_LAYOUT Layout, template <> void prk_gemm(const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, - const MKL_INT M, const MKL_INT N, const MKL_INT K, + const PRK_INT M, const PRK_INT N, const PRK_INT K, const double alpha, - const double * A, const MKL_INT lda, - const double * B, const MKL_INT ldb, + const double * A, const PRK_INT lda, + const double * B, const PRK_INT ldb, const double beta, - double * C, const MKL_INT ldc) + double * C, const PRK_INT ldc) { cblas_dgemm(Layout, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); } -#ifdef MKL_BF16 +#ifdef USE_BF16 void run_BF16(int iterations, int order) { double gemm_time{0}; const size_t nelems = (size_t)order * (size_t)order; - auto A = new MKL_BF16[nelems]; - auto B = new MKL_BF16[nelems]; + auto A = new PRK_BF16[nelems]; + auto B = new PRK_BF16[nelems]; auto C = new float[nelems]; for (int i=0; i