diff --git a/gpt2.f90 b/gpt2.f90
index c915b80..b8cb1e3 100644
--- a/gpt2.f90
+++ b/gpt2.f90
@@ -61,46 +61,68 @@ function ffn(x, fc_w, fc_b, proj_w, proj_b) result(y)
 y = linear(gelu(linear(x, fc_w, fc_b)), proj_w, proj_b)
 end function
 
-function attention(q, k, v, mask) result(y)
-real(sp), intent(in) :: q(:,:), k(:,:), v(:,:), mask(:,:)
-real(sp) :: y(size(v,1),size(q,2))
-real(sp) :: tmp(size(k,2),size(q,2))
+function attention(n_embd_head,n_seq,n_seq_x, q, k, v, mask) result(y)
+integer, intent(in) :: n_embd_head, n_seq, n_seq_x
+real(sp), intent(in) :: q(n_embd_head,n_seq_x), k(n_embd_head,n_seq), v(n_embd_head,n_seq), mask(n_seq,n_seq_x)
+real(sp) :: y(n_embd_head,n_seq_x)
+real(sp) :: tmp(n_seq,n_seq_x)
 !tmp = matmul(transpose(k), q)
 !call matmul_2d(transpose(k), q, tmp)
 call matmul_2d_t(k, q, tmp)
-call matmul_2d(v, softmax(tmp / sqrt(real(size(q,1),sp)) + mask), y)
+call matmul_2d(v, softmax(tmp / sqrt(real(n_embd_head,sp)) + mask), y)
 end function
 
-function mha(n_seq, n_embd, x, attn_w, attn_b, proj_w, proj_b, n_head) &
+function mha(n_seq, n_seq_x, n_embd, x, attn_w, attn_b, proj_w, proj_b, n_head, &
+        use_kv_cache, kv_cache) &
         result(y)
-integer, intent(in) :: n_seq, n_embd
-real(sp), intent(in) :: x(n_embd,n_seq), &
+integer, intent(in) :: n_seq, n_seq_x, n_embd
+real(sp), intent(in) :: x(n_embd,n_seq_x), &
     attn_w(3*n_embd,n_embd), attn_b(3*n_embd), &
     proj_w(n_embd,n_embd), proj_b(n_embd)
+real(sp), intent(inout) :: kv_cache(n_embd,n_seq,2)
 integer, intent(in) :: n_head
-real(sp) :: y(n_embd,n_seq)
-real(sp) :: causal_mask(n_seq,n_seq)
-real(sp) :: x2(3*n_embd,n_seq)
+logical, intent(in) :: use_kv_cache
+real(sp) :: y(n_embd,n_seq_x)
+real(sp) :: causal_mask(n_seq,n_seq_x)
+real(sp) :: x2(3*n_embd,n_seq_x)
 integer :: i, j
 ! Mask
-do j = 1, n_seq
-do i = 1, n_seq
-    if (i > j) then
-        causal_mask(i,j) = -1e10_sp
-    else
-        causal_mask(i,j) = 0
-    end if
-end do
-end do
+if (use_kv_cache) then
+    causal_mask = 0
+else
+    do j = 1, n_seq
+    do i = 1, n_seq
+        if (i > j) then
+            causal_mask(i,j) = -1e10_sp
+        else
+            causal_mask(i,j) = 0
+        end if
+    end do
+    end do
+end if
 x2 = linear(x, attn_w, attn_b)
 associate ( &
     q => x2((1-1)*n_embd+1:1*n_embd,:), &
     k => x2((2-1)*n_embd+1:2*n_embd,:), &
     v => x2((3-1)*n_embd+1:3*n_embd,:) &
     )
+    if (use_kv_cache) then
+        kv_cache(:,n_seq,1) = k(:,1)
+        kv_cache(:,n_seq,2) = v(:,1)
+    else
+        kv_cache(:,:,1) = k
+        kv_cache(:,:,2) = v
+    end if
+end associate
+associate ( &
+    q => x2((1-1)*n_embd+1:1*n_embd,:), &
+    k => kv_cache(:,:,1), &
+    v => kv_cache(:,:,2) &
+    )
 ! Perform attention over each head
 do i = 1, n_head
     y((i-1)*n_embd/n_head+1:i*n_embd/n_head,:) = attention( &
+        n_embd/n_head, n_seq, n_seq_x, &
         q((i-1)*n_embd/n_head+1:i*n_embd/n_head,:), &
         k((i-1)*n_embd/n_head+1:i*n_embd/n_head,:), &
         v((i-1)*n_embd/n_head+1:i*n_embd/n_head,:), &
@@ -112,31 +134,32 @@ function mha(n_seq, n_embd, x, attn_w, attn_b, proj_w, proj_b, n_head) &
 end function
 
-function transformer_block(x, mlp_fc_w, mlp_fc_b, mlp_proj_w, mlp_proj_b, &
+function transformer_block(n_seq, n_seq_x, n_embd, x, mlp_fc_w, mlp_fc_b, mlp_proj_w, mlp_proj_b, &
     attn_w, attn_b, attn_proj_w, attn_proj_b, ln1_g, ln1_b, ln2_g, ln2_b, &
-    n_head) result(y)
-real(sp), intent(in) :: x(:,:), &
+    n_head, use_kv_cache, kv_cache) result(y)
+real(sp), intent(in) :: x(n_embd,n_seq_x), &
    mlp_fc_w(:,:), mlp_fc_b(:), &
    mlp_proj_w(:,:), mlp_proj_b(:), &
    attn_w(:,:), attn_b(:), attn_proj_w(:,:), attn_proj_b(:), &
    ln1_g(:), ln1_b(:), ln2_g(:), ln2_b(:)
 integer, intent(in) :: n_head
-real(sp) :: y(size(x,1),size(x,2))
-integer :: n_seq, n_embd
-n_embd = size(x,1)
-n_seq = size(x,2)
-y = x + mha(n_seq, n_embd, layer_norm(x, ln1_g, ln1_b, 1e-5_sp), &
-    attn_w, attn_b, attn_proj_w, attn_proj_b, n_head)
+integer, intent(in) :: n_seq, n_seq_x, n_embd
+real(sp) :: y(n_embd,n_seq_x)
+logical, intent(in) :: use_kv_cache
+real(sp), intent(inout) :: kv_cache(n_embd,n_seq,2)
+y = x + mha(n_seq, n_seq_x, n_embd, layer_norm(x, ln1_g, ln1_b, 1e-5_sp), &
+    attn_w, attn_b, attn_proj_w, attn_proj_b, n_head, use_kv_cache, kv_cache)
 y = y + ffn(layer_norm(y, ln2_g, ln2_b, 1e-5_sp), &
     mlp_fc_w, mlp_fc_b, mlp_proj_w, mlp_proj_b)
 end function
 
-function gpt2(n_vocab, n_ctx, n_seq, n_embd, n_layer, n_head, input, &
+function gpt2(n_vocab, n_ctx, n_seq, n_seq_x, n_embd, n_layer, n_head, input, &
         wte, wpe, &
         mlp_fc_w, mlp_fc_b, mlp_proj_w, mlp_proj_b, &
         attn_w, attn_b, attn_proj_w, attn_proj_b, &
-        ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b) result(y)
-integer, intent(in) :: n_vocab, n_ctx, n_seq, n_embd, n_layer, n_head
+        ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b, &
+        use_kv_cache, kv_cache) result(y)
+integer, intent(in) :: n_vocab, n_ctx, n_seq, n_seq_x, n_embd, n_layer, n_head
 integer, intent(in) :: input(n_seq)
 real(sp), intent(in) :: wte(n_embd,n_vocab), wpe(n_embd,n_ctx), &
     mlp_fc_w(4*n_embd,n_embd,n_layer), mlp_fc_b(4*n_embd,n_layer), &
@@ -146,19 +169,26 @@ function gpt2(n_vocab, n_ctx, n_seq, n_embd, n_layer, n_head, input, &
     ln1_b(n_embd,n_layer), ln1_g(n_embd,n_layer), &
     ln2_b(n_embd,n_layer), ln2_g(n_embd,n_layer), &
     lnf_b(n_embd), lnf_g(n_embd)
-real(sp) :: y(n_vocab,n_seq)
-real(sp) :: x(n_embd,n_seq)
+logical, intent(in) :: use_kv_cache
+real(sp), intent(inout) :: kv_cache(n_embd,n_seq,2,n_layer)
+real(sp) :: y(n_vocab,n_seq_x)
+real(sp) :: x(n_embd,n_seq_x)
 integer :: i
-do i = 1, n_seq
-    x(:,i) = wte(:,input(i)+1) + wpe(:,i)
-end do
+if (use_kv_cache) then
+    i = n_seq
+    x(:,1) = wte(:,input(i)+1) + wpe(:,i)
+else
+    do i = 1, n_seq
+        x(:,i) = wte(:,input(i)+1) + wpe(:,i)
+    end do
+end if
 do i = 1, n_layer
-    x = transformer_block(x, &
+    x = transformer_block(n_seq, n_seq_x, n_embd, x, &
         mlp_fc_w(:,:,i), mlp_fc_b(:,i), &
         mlp_proj_w(:,:,i), mlp_proj_b(:,i), &
         attn_w(:,:,i), attn_b(:,i), attn_proj_w(:,:,i), attn_proj_b(:,i), &
         ln1_g(:,i), ln1_b(:,i), ln2_g(:,i), ln2_b(:,i), &
-        n_head)
+        n_head, use_kv_cache, kv_cache(:,:,:,i))
 end do
 x = layer_norm(x, lnf_g, lnf_b, 1e-5)
 !y = matmul(transpose(wte), x)
@@ -170,7 +200,7 @@ function generate(n_tokens_to_generate, &
         wte, wpe, &
         mlp_fc_w, mlp_fc_b, mlp_proj_w, mlp_proj_b, &
         attn_w, attn_b, attn_proj_w, attn_proj_b, &
-        ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b) result(output)
+        ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b, use_cache) result(output)
 integer, intent(in) :: n_vocab, n_ctx, n_seq, n_embd, n_layer, n_head, &
     n_tokens_to_generate
 integer, intent(in) :: input(n_seq)
@@ -182,22 +212,37 @@ function generate(n_tokens_to_generate, &
     ln1_b(n_embd,n_layer), ln1_g(n_embd,n_layer), &
     ln2_b(n_embd,n_layer), ln2_g(n_embd,n_layer), &
     lnf_b(n_embd), lnf_g(n_embd)
+logical, intent(in) :: use_cache
 integer :: output(n_tokens_to_generate)
 real(sp), allocatable :: logits(:,:)
 integer :: i
+integer :: n_seq2, n_seq_x
 integer :: next_id
 integer, allocatable :: input2(:)
+logical :: use_kv_cache
+real(sp) :: kv_cache(n_embd,n_seq+n_tokens_to_generate,2,n_layer)
 allocate(input2(size(input)))
 input2 = input
 do i = 1, n_tokens_to_generate
-    allocate(logits(n_vocab, size(input2)))
-    logits = gpt2(n_vocab, n_ctx, size(input2), n_embd, n_layer, n_head, &
+    if (use_cache) then
+        use_kv_cache = (i > 1) ! Use cache for subsequent tokens
+    else
+        use_kv_cache = .false.
+    end if
+    n_seq2 = size(input2)
+    if (use_kv_cache) then
+        n_seq_x = 1
+    else
+        n_seq_x = n_seq2
+    end if
+    allocate(logits(n_vocab, n_seq_x))
+    logits = gpt2(n_vocab, n_ctx, n_seq2, n_seq_x, n_embd, n_layer, n_head, &
         input2, &
         wte, wpe, &
         mlp_fc_w, mlp_fc_b, mlp_proj_w, mlp_proj_b, &
         attn_w, attn_b, attn_proj_w, attn_proj_b, &
-        ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b)
-    next_id = maxloc(logits(:,size(logits,2)), dim=1)-1
+        ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b, use_kv_cache, kv_cache(:,:n_seq2,:,:))
+    next_id = maxloc(logits(:,n_seq_x), dim=1)-1
     print *, i, next_id
     input2 = [input2, next_id]
     deallocate(logits)
diff --git a/main.f90 b/main.f90
index aa7e41d..42e8d0c 100644
--- a/main.f90
+++ b/main.f90
@@ -22,6 +22,7 @@ program gpt2
 character(:), allocatable :: output_txt
 real(dp) :: t1, t2, t1o, t2o
 integer :: u
+logical :: use_cache
 
 ! Load the model
 print "(a)", "Loading the model..."
@@ -86,13 +87,14 @@ program gpt2
 print "(a)", "Running model..."
 call cpu_time(t1)
 t1o = omp_get_wtime()
+use_cache = .true.
 output = generate(n_tokens_to_generate, n_vocab, n_ctx, size(input), n_embd, &
     n_layer, n_head, &
     input, &
     wte, wpe, &
     mlp_fc_w, mlp_fc_b, mlp_proj_w, mlp_proj_b, &
     attn_w, attn_b, attn_proj_w, attn_proj_b, &
-    ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b)
+    ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b, use_cache)
 t2o = omp_get_wtime()
 call cpu_time(t2)
 print "(a,f8.3,a,f4.2,a)", " done. Time:", t2o-t1o, "s (", (t2-t1)/(t2o-t1o), "x)"
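
The patch threads a per-layer KV cache, kv_cache(n_embd, n_seq, 2, n_layer), from generate through gpt2 and transformer_block into mha: the first pass runs the whole prompt and fills columns 1..n_seq, while every later pass feeds only the newest token (n_seq_x = 1), stores its K/V in column n_seq, and attends against all cached columns with a zero mask (a single query at the last position may see every earlier key, so no causal masking is needed). The program below is a minimal, self-contained sketch of just that indexing pattern; the program name, toy sizes, and constant fill values are hypothetical stand-ins for illustration, not code from the patch.

! Sketch of the cache growth pattern (hypothetical names and values).
program kv_cache_sketch
implicit none
integer, parameter :: n_embd = 4, n_prompt = 3, n_gen = 2
real :: kv_cache(n_embd, n_prompt + n_gen, 2)   ! last dim: 1 = K, 2 = V
integer :: n_seq, step

! First pass (use_kv_cache = .false.): K and V for the whole prompt are
! computed and stored in columns 1..n_prompt.
n_seq = n_prompt
kv_cache(:, 1:n_seq, 1) = 1.0   ! stands in for K of all prompt tokens
kv_cache(:, 1:n_seq, 2) = 2.0   ! stands in for V of all prompt tokens

! Later passes (use_kv_cache = .true.): only the newest token is processed;
! its K/V land in column n_seq and the older columns are reused unchanged.
do step = 1, n_gen
    n_seq = n_prompt + step
    kv_cache(:, n_seq, 1) = real(step)        ! stands in for K of the new token
    kv_cache(:, n_seq, 2) = real(step) + 0.5  ! stands in for V of the new token
    print "(a,i0,a,i0)", "step ", step, ": cached columns = ", n_seq
end do
end program kv_cache_sketch

The sketch mirrors the design choice in generate: the cache is sized once for n_seq + n_tokens_to_generate columns, so appending a token never reallocates; only the slice kv_cache(:,:n_seq2,:,:) actually in use is passed down to gpt2.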