-
Notifications
You must be signed in to change notification settings - Fork 426
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: added k, v cache for inference speed up #7
base: main
Are you sure you want to change the base?
Conversation
Thanks for the implementation! What kind of speedups did you get with this and did you get an identical output to the non-kv cache version? Just FYI, I'm going to leave this unmerged to keep the implementation as simple as possible. However, will keep this PR open if people want to reference it in the future. There's also an inference optimization section in my blog post with some further resources to read up on. |
Yes, the Output is identical. I am seeing a 25% speedup of CPU. the most powerful machines on the planet.
The computer is a machine that can perform complex calculations, and it can perform these calculations in a way that is very similar to the human brain. Yeah, it makes sense to not merge it. Probably, we can create another file |
Hi @immortal3 I love the minimal implementation I'm having trouble reproducing the 25% speedup though. I've been using |
@clam004 i don't remember exactly how I ended as 25% speedup but it was definitely not a scientific one. 😄 The speedup number will heavily rely on the combination of CPU/Memory and the length of the input tokens. So, I think you might not be getting the exact number 25%, but try feeding a sufficiently longer sequence that should definitely indicate some performance improvement compared to a normal forward pass with KV cache. On the proper comparison side, I am not sure if it would be worth it (time-wise) at this point to do it thoroughly. |
for _ in tqdm(range(n_tokens_to_generate), "generating"): # auto-regressive decode loop | ||
logits = gpt2(inputs, **params, n_head=n_head) # model forward pass | ||
logits, kvcache = gpt2(inputs, **params, n_head=n_head, kvcache=kvcache) # model forward pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The main benefit of KV caching is that you don't need to recalculate the MLPs again for the tokens you already calculated the forward for, and so in the decoding phase you only pass the new token as input to the network.
You should only pass the next_id
as input in the decoding phase. In prefill phase, the initial inputs
should be passed. checkout https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py#L72 or https://github.com/meta-llama/llama/blob/main/llama/generation.py#L187C51-L187C59 for an example.
more: https://www.perplexity.ai/search/what-should-be-the-input-to-th-bsYpXZiuRFinjT11Ck33EA#0
wpe_out = wpe[range(len(inputs))] | ||
else: | ||
wpe_out = wpe[[len(inputs)-1]] | ||
inputs = [inputs[-1]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@panaali You're correct, if kvcache is there then only the last token should be passed. But this is I being lazy and don't want to change function signatures. So, I am doing it inside function. I just use the last token as input if kvcache is there.
Hi, @jaymody, Awesome blog post. I was interested in learning kvcache during inference and searched for it but existing articles on kvcache don't focus on the implementation part of it. So, I decided to implement it in picoGPT.
Are you interested in writing a post for optimization inference time? I would love to collaborate on it.