Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better bf16 support #61

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions candle-kernels/src/affine.cu
Original file line number Diff line number Diff line change
@@ -28,7 +28,6 @@ extern "C" __global__ void FN_NAME( \
} \
} \

#if __CUDA_ARCH__ >= 800
#include "cuda_fp8.h"
#include "cuda_bf16.h"

@@ -37,11 +36,8 @@ AFFINE_OP(__nv_bfloat16, affine_bf16, x * mul + add)
#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))

AFFINE_OP(__nv_fp8_e4m3, affine_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) * F8E4M3_TO_FLOAT(mul) + F8E4M3_TO_FLOAT(add)))
#endif

#if __CUDA_ARCH__ >= 530
AFFINE_OP(__half, affine_f16, x * mul + add)
#endif

AFFINE_OP(float, affine_f32, x * mul + add)
AFFINE_OP(double, affine_f64, x * mul + add)
4 changes: 0 additions & 4 deletions candle-kernels/src/binary.cu
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "binary_op_macros.cuh"
#include<stdint.h>

#if __CUDA_ARCH__ >= 800
#include "cuda_fp8.h"
#include "cuda_bf16.h"

@@ -32,9 +31,7 @@ BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, lt_f8_e4m3, F8E4M3_TO_FLOAT(x) < F8E4M3_TO
BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, le_f8_e4m3, F8E4M3_TO_FLOAT(x) <= F8E4M3_TO_FLOAT(y))
BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, gt_f8_e4m3, F8E4M3_TO_FLOAT(x) > F8E4M3_TO_FLOAT(y))
BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, ge_f8_e4m3, F8E4M3_TO_FLOAT(x) >= F8E4M3_TO_FLOAT(y))
#endif

#if __CUDA_ARCH__ >= 530
BINARY_OP(__half, badd_f16, x + y)
BINARY_OP(__half, bdiv_f16, x / y)
BINARY_OP(__half, bmul_f16, x * y)
@@ -47,7 +44,6 @@ BINARY_OP_OUT(__half, uint8_t, lt_f16, x < y)
BINARY_OP_OUT(__half, uint8_t, le_f16, x <= y)
BINARY_OP_OUT(__half, uint8_t, gt_f16, x > y)
BINARY_OP_OUT(__half, uint8_t, ge_f16, x >= y)
#endif

BINARY_OP(float, badd_f32, x + y)
BINARY_OP(double, badd_f64, x + y);
7 changes: 1 addition & 6 deletions candle-kernels/src/cast.cu
Original file line number Diff line number Diff line change
@@ -24,7 +24,6 @@ __device__ void cast_(
}
}

#if __CUDA_ARCH__ >= 800
#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))

template <typename T>
@@ -71,7 +70,6 @@ __device__ void cast_fp8_into_(
}
}
}
#endif

template <typename S, typename T, typename I>
__device__ void cast_through(
@@ -143,9 +141,9 @@ extern "C" __global__ void FN_NAME( \
cast_through<SRC_TYPENAME, DST_TYPENAME, INT_TYPENAME>(numel, num_dims, info, inp, out); \
} \

#if __CUDA_ARCH__ >= 800
#include "cuda_fp8.h"
#include "cuda_bf16.h"

CAST_OP(__nv_bfloat16, __nv_bfloat16, cast_bf16_bf16)
CAST_OP(__nv_fp8_e4m3, __nv_fp8_e4m3, cast_f8_e4m3_f8_e4m3)

@@ -174,9 +172,7 @@ CAST_OP_FP8_INTO(int32_t, __nv_fp8_e4m3, cast_i32_f8_e4m3)
CAST_OP_FP8(__nv_fp8_e4m3, int32_t, cast_f8_e4m3_i32)
CAST_OP_FP8(__nv_fp8_e4m3, __nv_bfloat16, cast_f8_e4m3_bf16)
CAST_OP_FP8_INTO(__nv_bfloat16, __nv_fp8_e4m3, cast_bf16_f8_e4m3)
#endif

#if __CUDA_ARCH__ >= 530
CAST_OP(__half, __half, cast_f16_f16)

CAST_THROUGH_OP(__half, uint8_t, float, cast_f16_u8)
@@ -189,7 +185,6 @@ CAST_OP(float, __half, cast_f32_f16)
CAST_OP(double, __half, cast_f64_f16)
CAST_OP(int32_t, __half, cast_i32_f16 )
CAST_THROUGH_OP(__half, int32_t, float, cast_f16_i32)
#endif

CAST_OP(uint32_t, uint32_t, cast_u32_u32)
CAST_OP(uint32_t, uint8_t, cast_u32_u8 )
30 changes: 15 additions & 15 deletions candle-kernels/src/compatibility.cuh
Original file line number Diff line number Diff line change
@@ -39,21 +39,21 @@ __device__ double atomicAdd(double* address, double val) {
// The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher.
// Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
__device__ __half atomicAdd(__half *address, __half val) {
// unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
// unsigned int old = *address_as_ui;
// unsigned int assumed;
// bool unaligned = (size_t) address & 2;
// do {
// assumed = old;
// unsigned int hsum;
// hsum = unaligned ? (old >> 16) : (old & 0xffff);
// hsum = __half_as_ushort(__ushort_as_half(hsum) + val);
// old = atomicCAS(address_as_ui, assumed,
// unaligned ? (old & 0xffff) | (hsum << 16) : (old & 0xffff0000) | hsum
// );

// } while (assumed != old);
// return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;
bool unaligned = (size_t) address & 2;
do {
assumed = old;
unsigned int hsum;
hsum = unaligned ? (old >> 16) : (old & 0xffff);
hsum = __half_as_ushort(__ushort_as_half(hsum) + val);
old = atomicCAS(address_as_ui, assumed,
unaligned ? (old & 0xffff) | (hsum << 16) : (old & 0xffff0000) | hsum
);

} while (assumed != old);
return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
}
#endif

4 changes: 0 additions & 4 deletions candle-kernels/src/conv.cu
Original file line number Diff line number Diff line change
@@ -691,7 +691,6 @@ extern "C" __global__ void FN_NAME( \
upsample_nearest2d<TYPENAME>(w_out, h_out, w_scale, h_scale, info, src, dst); \
} \

#if __CUDA_ARCH__ >= 800
#include "cuda_bf16.h"

CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
@@ -716,9 +715,7 @@ COL2IM1D_OP(__nv_bfloat16, col2im1d_bf16)
// IM2COL_OP(__nv_fp8_e4m3, im2col_f8_e5m)
// IM2COL1D_OP(__nv_fp8_e4m3, im2col1d_f8_e5m)
// COL2IM1D_OP(__nv_fp8_e4m3, col2im1d_f8_e5m)
#endif

#if __CUDA_ARCH__ >= 530
CONV1D_OP(__half, float, conv1d_f16)
CONV2D_OP(__half, float, conv2d_f16)
CONVT1D_OP(__half, float, conv_transpose1d_f16)
@@ -729,7 +726,6 @@ UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
IM2COL_OP(__half, im2col_f16)
IM2COL1D_OP(__half, im2col1d_f16)
COL2IM1D_OP(__half, col2im1d_f16)
#endif

CONV1D_OP(float, float, conv1d_f32)
CONV1D_OP(double, double, conv1d_f64)
11 changes: 3 additions & 8 deletions candle-kernels/src/cuda_utils.cuh
Original file line number Diff line number Diff line change
@@ -2,6 +2,9 @@
#include<stdint.h>
#include<cmath>

#include "cuda_fp8.h"
#include "cuda_bf16.h"

// TODO: This is often used to check that the data is contiguous so that
// kernels can be easily mapped. However this only returns true for row
// major, if all the inputs are column major, we could apply the fast path
@@ -191,7 +194,6 @@ __device__ __forceinline__ uint32_t ming(uint32_t a, uint32_t b) { return min(a,
__device__ __forceinline__ uint32_t maxg(uint32_t a, uint32_t b) { return max(a, b); }
__device__ __forceinline__ uint8_t ming(uint8_t a, uint8_t b) { return min(a, b); }
__device__ __forceinline__ uint8_t maxg(uint8_t a, uint8_t b) { return max(a, b); }
#if __CUDA_ARCH__ >= 530
__device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); }
__device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); }
__device__ __forceinline__ __half sqrtg(__half a) { return hsqrt(a); }
@@ -210,11 +212,6 @@ __device__ __forceinline__ __half logg(__half a) { return hlog(a); }
__device__ __forceinline__ __half expg(__half a) { return hexp(a); }
__device__ __forceinline__ __half absg(__half a) { return __habs(a); }
__device__ __forceinline__ __half copysigng(__half a, __half b) { return __float2half(copysignf(__half2float(a), __half2float(b))); }
#endif

#if __CUDA_ARCH__ >= 800
#include "cuda_fp8.h"
#include "cuda_bf16.h"

__device__ __forceinline__ __nv_bfloat16 powg(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(powf(__bfloat162float(a), __bfloat162float(b))); }
__device__ __forceinline__ bool isnang(__nv_bfloat16 a) { return __hisnan(a); }
@@ -256,5 +253,3 @@ __device__ __forceinline__ __nv_fp8_e4m3 expg(__nv_fp8_e4m3 a) { return __nv_fp8
__device__ __forceinline__ __nv_fp8_e4m3 absg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(fabsf(F8E4M3_TO_FLOAT(a))); }
__device__ __forceinline__ __nv_fp8_e4m3 copysigng(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(copysignf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); }


#endif
11 changes: 4 additions & 7 deletions candle-kernels/src/fill.cu
Original file line number Diff line number Diff line change
@@ -40,18 +40,15 @@ COPY2D_OP(int16_t, copy2d_i16)
COPY2D_OP(int32_t, copy2d_i32)
COPY2D_OP(int64_t, copy2d_i64)

#if __CUDA_ARCH__ >= 530
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__half, copy2d_f16)
#endif

#if __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#include <cuda_fp8.h>

extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__half, copy2d_f16)

extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__nv_bfloat16, copy2d_bf16)

extern "C" __global__ void fill_f8_e4m3(__nv_fp8_e4m3 *buf, __nv_fp8_e4m3 value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__nv_fp8_e4m3, copy2d_f8_e4m3)
#endif

4 changes: 1 addition & 3 deletions candle-kernels/src/fused_rms_norm.cu
Original file line number Diff line number Diff line change
@@ -76,7 +76,5 @@ extern "C" __global__ void FN_NAME(\
RMS_NORM_OP(rms_norm_f32, float)
RMS_NORM_OP(rms_norm_f16, __half)

#if __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
RMS_NORM_OP(rms_norm_bf16, __nv_bfloat16)
#endif
RMS_NORM_OP(rms_norm_bf16, __nv_bfloat16)
3 changes: 1 addition & 2 deletions candle-kernels/src/fused_rope.cu
Original file line number Diff line number Diff line change
@@ -189,7 +189,6 @@ extern "C" __global__ void rotary_embedding_kernel_neox_f64(
apply_rotary_embedding<double, true>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
}

#if __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
extern "C" __global__ void rotary_embedding_kernel_bf16(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
@@ -228,4 +227,4 @@ extern "C" __global__ void rotary_embedding_kernel_neox_bf16(

apply_rotary_embedding<__nv_bfloat16, true>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
}
#endif

6 changes: 0 additions & 6 deletions candle-kernels/src/indexing.cu
Original file line number Diff line number Diff line change
@@ -99,7 +99,6 @@ __device__ void index_add(
}
}

#if __CUDA_ARCH__ >= 800
#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))

template<typename I>
@@ -148,7 +147,6 @@ __device__ void index_add_f8(
}
}
}
#endif

#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
@@ -220,7 +218,6 @@ extern "C" __global__ void FN_NAME( \
) { scatter_add_f8(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \


#if __CUDA_ARCH__ >= 800
#include "cuda_fp8.h"
#include "cuda_bf16.h"

@@ -265,9 +262,7 @@ SA_OP_F8(__nv_fp8_e4m3, int32_t, sa_i32_f8_e4m3)
SA_OP_F8(__nv_fp8_e4m3, int64_t, sa_i64_f8_e4m3)
SA_OP_F8(__nv_fp8_e4m3, uint32_t, sa_u32_f8_e4m3)
SA_OP_F8(__nv_fp8_e4m3, uint8_t, sa_u8_f8_e4m3)
#endif

#if __CUDA_ARCH__ >= 530
IS_OP(__half, int16_t, is_i16_f16)
IS_OP(__half, int32_t, is_i32_f16)
IS_OP(__half, int64_t, is_i64_f16)
@@ -288,7 +283,6 @@ SA_OP(__half, int32_t, sa_i32_f16)
SA_OP(__half, int64_t, sa_i64_f16)
SA_OP(__half, uint32_t, sa_u32_f16)
SA_OP(__half, uint8_t, sa_u8_f16)
#endif

IS_OP(float, int16_t, is_i16_f32)
IS_OP(double, int16_t, is_i16_f64)
10 changes: 3 additions & 7 deletions candle-kernels/src/kvconcat.cu
Original file line number Diff line number Diff line change
@@ -44,14 +44,10 @@ KVCONCAT_OP(uint8_t, kvconcat_u8)
KVCONCAT_OP(double, kvconcat_f64)
KVCONCAT_OP(float, kvconcat_f32)

#if __CUDA_ARCH__ >= 530
KVCONCAT_OP(__half, kvconcat_f16)
#endif

#if __CUDA_ARCH__ >= 800
#include "cuda_fp8.h"
#include "cuda_bf16.h"

KVCONCAT_OP(__nv_bfloat16, kvconcat_bf16)

KVCONCAT_OP(__half, kvconcat_f16)

KVCONCAT_OP(__nv_fp8_e4m3, kvconcat_f8_e4m3)
#endif
4 changes: 0 additions & 4 deletions candle-kernels/src/reduce.cu
Original file line number Diff line number Diff line change
@@ -571,7 +571,6 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d); \
} \

#if __CUDA_ARCH__ >= 800
#include "cuda_bf16.h"
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
@@ -587,16 +586,13 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
// LAYERNORM_OP(__nv_fp8_e4m3, layernorm_fp8_e4m3)
// ROPE_OP(__nv_fp8_e4m3, rope_fp8_e4m3, rope_i_fp8_e4m3, rope_thd_fp8_e4m3)
// FAST_OP(__nv_fp8_e4m3, fast_min_fp8_e4m3, fast_max_fp8_e4m3, fast_argmin_fp8_e4m3, fast_argmax_fp8_e4m3, fast_sum_fp8_e4m3)
#endif

#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
RMSNORM_OP(__half, rmsnorm_f16)
LAYERNORM_OP(__half, layernorm_f16)
ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
SUM_OP(__half, sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
#endif

SUM_OP(float, sum_f32)
SUM_OP(double, sum_f64)
4 changes: 0 additions & 4 deletions candle-kernels/src/sort.cu
Original file line number Diff line number Diff line change
@@ -73,17 +73,13 @@ extern "C" __global__ void asort_desc_##RUST_NAME( \
k_argsort<SORT_ORDER_DESC>(x, dst, ncols, ncols_pad); \
} \

#if __CUDA_ARCH__ >= 800
#include "cuda_bf16.h"
ASORT_OP(__nv_bfloat16, bf16)

// NOTE: No sort ops for f8
// ASORT_OP(__nv_fp8_e4m3, fp8_e4m3)
#endif

#if __CUDA_ARCH__ >= 530
ASORT_OP(__half, f16)
#endif

ASORT_OP(float, f32)
ASORT_OP(double, f64)
4 changes: 0 additions & 4 deletions candle-kernels/src/ternary.cu
Original file line number Diff line number Diff line change
@@ -32,7 +32,6 @@ extern "C" __global__ void FN_NAME( \
} \
} \

#if __CUDA_ARCH__ >= 800
#include "cuda_fp8.h"
#include "cuda_bf16.h"

@@ -47,15 +46,12 @@ WHERE_OP(__nv_fp8_e4m3, int32_t, where_i32_fp8_e4m3)
WHERE_OP(__nv_fp8_e4m3, int64_t, where_i64_fp8_e4m3)
WHERE_OP(__nv_fp8_e4m3, uint32_t, where_u32_fp8_e4m3)
WHERE_OP(__nv_fp8_e4m3, uint8_t, where_u8_fp8_e4m3)
#endif

#if __CUDA_ARCH__ >= 530
WHERE_OP(__half, int16_t, where_i16_f16)
WHERE_OP(__half, int32_t, where_i32_f16)
WHERE_OP(__half, int64_t, where_i64_f16)
WHERE_OP(__half, uint32_t, where_u32_f16)
WHERE_OP(__half, uint8_t, where_u8_f16)
#endif

WHERE_OP(float, int16_t, where_i16_f32)
WHERE_OP(double, int16_t, where_i16_f64)
4 changes: 0 additions & 4 deletions candle-kernels/src/unary.cu
Original file line number Diff line number Diff line change
@@ -97,7 +97,6 @@ __device__ T sign_(T t) {
}


#if __CUDA_ARCH__ >= 800
#include "cuda_fp8.h"
#include "cuda_bf16.h"

@@ -152,9 +151,7 @@ UNARY_OP(__nv_fp8_e4m3, usilu_fp8_e4m3, __nv_fp8_e4m3(silu_fwd(F8E4M3_TO_FLOAT(x
UNARY_OP1(__nv_fp8_e4m3, upowf_fp8_e4m3, powg(x, param))
UNARY_OP(__nv_fp8_e4m3, usign_fp8_e4m3, __nv_fp8_e4m3(sign_(F8E4M3_TO_FLOAT(x))))
UNARY_OP(__nv_fp8_e4m3, usigmoid_fp8_e4m3, __nv_fp8_e4m3(sigmoid_fwd(F8E4M3_TO_FLOAT(x))))
#endif

#if __CUDA_ARCH__ >= 530
UNARY_OP(__half, ucopy_f16, x)
UNARY_OP(__half, uneg_f16, -x)
UNARY_OP(__half, urecip_f16, recipg(x))
@@ -179,7 +176,6 @@ UNARY_OP(__half, usilu_f16, silu_fwd(x))
UNARY_OP1(__half, upowf_f16, powg(x, param))
UNARY_OP(__half, usign_f16, sign_(x))
UNARY_OP(__half, usigmoid_f16, sigmoid_fwd(x))
#endif

UNARY_OP(uint8_t, ucopy_u8, x)
UNARY_OP(uint32_t, ucopy_u32, x)
Loading