diff --git a/dev/cuda/matmul_forward.cu b/dev/cuda/matmul_forward.cu index b4794903a..ec13805a3 100644 --- a/dev/cuda/matmul_forward.cu +++ b/dev/cuda/matmul_forward.cu @@ -72,11 +72,9 @@ __global__ void matmul_forward_kernel1(float* out, int bt = blockIdx.x * blockDim.x + threadIdx.x; int oc = blockIdx.y * blockDim.y + threadIdx.y; if (bt < BT && oc < OC) { - int b = bt / BT; - int t = bt % BT; float val = (bias != NULL) ? bias[oc] : 0.0f; - const float* wrow = weight + oc*C; - const float* inp_bt = inp + b * BT * C + t * C; + const float* wrow = weight + oc * C; + const float* inp_bt = inp + bt * C; for (int i = 0; i < C; i++) { val += inp_bt[i] * wrow[i]; }