
Unnecessary KV Cache Updates During Training Mode #35648

Open

Hannibal046 opened this issue Jan 13, 2025 · 3 comments

@Hannibal046

System Info

  • transformers version: 4.48.0
  • Platform: Linux-5.4.0-58-generic-x86_64-with-glibc2.31
  • Python version: 3.11.10
  • Huggingface_hub version: 0.26.3
  • Safetensors version: 0.4.5
  • Accelerate version: 1.1.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.1+cu118 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA A100-PCIE-40GB

Who can help?

@ArthurZucker @muellerz @SunMarc

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I've identified an issue regarding unnecessary KV cache updates during training, which affects all current LLM models in the library and impacts both memory efficiency and torch.compile compatibility.

Taking src/transformers/models/llama/modeling_llama.py as an example:

  1. The past_key_values is set as a DynamicCache object even during training:

    if use_cache and past_key_values is None:
        past_key_values = DynamicCache()

  2. This leads to unnecessary memory allocation for storing the KV cache during the forward pass (a minimal sketch demonstrating this follows the recompile log below):

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

  3. Additionally, this causes issues with torch.compile, resulting in multiple recompilations due to layer_idx guards (from my internal test):
[rank6]:V0113 13:51:42.251000 31893 site-packages/torch/_dynamo/guards.py:2813] [10/6] [__recompiles] Recompiling function forward in /weka-jd/prod/deepseek/permanent/shared/chengxin/workspace/research_project/lingua/apps/main/hf_transformers/modeling_transformer.py:315
[rank6]:V0113 13:51:42.251000 31893 site-packages/torch/_dynamo/guards.py:2813] [10/6] [__recompiles]     triggered by the following guard failure(s):
[rank6]:V0113 13:51:42.251000 31893 site-packages/torch/_dynamo/guards.py:2813] [10/6] [__recompiles]     - 10/5: L['self']._modules['self_attn'].layer_idx == 5              
[rank6]:V0113 13:51:42.251000 31893 site-packages/torch/_dynamo/guards.py:2813] [10/6] [__recompiles]     - 10/4: L['self']._modules['self_attn'].layer_idx == 4              
[rank6]:V0113 13:51:42.251000 31893 site-packages/torch/_dynamo/guards.py:2813] [10/6] [__recompiles]     - 10/3: L['self']._modules['self_attn'].layer_idx == 3              
[rank6]:V0113 13:51:42.251000 31893 site-packages/torch/_dynamo/guards.py:2813] [10/6] [__recompiles]     - 10/2: L['self']._modules['self_attn'].layer_idx == 2              
[rank6]:V0113 13:51:42.251000 31893 site-packages/torch/_dynamo/guards.py:2813] [10/6] [__recompiles]     - 10/1: L['self']._modules['self_attn'].layer_idx == 1              
[rank6]:V0113 13:51:42.251000 31893 site-packages/torch/_dynamo/guards.py:2813] [10/6] [__recompiles]     - 10/0: L['self']._modules['self_attn'].layer_idx == 0              
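
To make item 2 concrete, here is a minimal sketch (not from the original report; the checkpoint name is only a placeholder) showing that a DynamicCache is allocated and populated even though the model is in training mode:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "meta-llama/Llama-3.2-1B"  # placeholder; any causal LM shows the same behavior
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.train()  # put the model in training mode

    inputs = tokenizer("Hello world", return_tensors="pt")
    # use_cache is not passed, so it falls back to config.use_cache (True by default)
    outputs = model(**inputs, labels=inputs["input_ids"])

    # A DynamicCache has been allocated and filled even though we never generate
    print(model.training, type(outputs.past_key_values))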

Workaround

A temporary solution is to explicitly set use_cache=False during training:

outputs = model(input_ids=input_ids, labels=labels, use_cache=False)
loss = outputs.loss
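
For context, a hedged sketch of the workaround inside a typical training step; model, optimizer, and the dataloader batches are assumed to come from the surrounding training code:

    for batch in dataloader:
        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
            use_cache=False,  # skip the DynamicCache allocation during training
        )
        outputs.loss.backward()
        optimizer.step()
        optimizer.zero_grad()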

Expected behavior

No KV cache should be allocated or updated during training.

@Hannibal046
Author

Or we could change from:

use_cache = use_cache if use_cache is not None else self.config.use_cache

to:

if use_cache is None:
    use_cache = False if self.training else self.config.use_cache

@Rocketknight1
Member

cc @gante for cache as well

@gante
Member

gante commented Jan 14, 2025

@Hannibal046 👋

@Rocketknight1 and I chatted offline, and we can't think of a reason why the cache should stay active while training. I'm going to update to the pattern you wrote: if the user doesn't manually specify use_cache, only fall back to the config value (which is True by default) when not training.
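
A small illustration of that intended behavior (the resolve_use_cache helper below is hypothetical, written only to spell out the defaults being discussed):

    from transformers import AutoConfig

    def resolve_use_cache(use_cache, training, config):
        # Hypothetical helper mirroring the proposed default: an explicit value
        # always wins; otherwise training disables the cache and inference
        # falls back to config.use_cache (True by default).
        if use_cache is None:
            return False if training else config.use_cache
        return use_cache

    config = AutoConfig.for_model("llama")          # use_cache defaults to True
    print(resolve_use_cache(None, True, config))    # False -> no cache while training
    print(resolve_use_cache(None, False, config))   # True  -> cache kept at inference
    print(resolve_use_cache(False, False, config))  # False -> explicit value wins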
