Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test: generate with torch.compile(model.forward) as a fast test #34544

Merged
merged 13 commits into from
Jan 28, 2025

Conversation

gante
Copy link
Member

@gante gante commented Oct 31, 2024

What does this PR do?

Follow-up to #34464

This PR:

  1. Converts test_generate_compile_model_forward to a fast test. This means we will check generate with torch.compile(model.forward) at each commit on ALL models that support StaticCache 💛
  2. Fixes failing cases of test_generate_compile_model_forward whenever possible
  3. Tags models with _supports_static_cache = False #Reason when the model doesn't support torch.compile(model.forward)

py.test tests/models/ -k test_generate_compile is all green, takes ~2 mins to run on all models on my machine

@gante gante requested review from ydshieh and ArthurZucker October 31, 2024 18:25
Copy link
Collaborator

@ydshieh ydshieh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love this!

Q: Is it really fast ...?

Remark: I feel get_max_cache_length is a better name than get_max_cache_shape but OK I know not great to change name all the time.

@gante
Copy link
Member Author

gante commented Oct 31, 2024

Q: Is it really fast ...?

@ydshieh yes :D
Screenshot 2024-10-31 at 18 41 14

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't mind, tho I don't think our priority should be this (full compile vs compile forward in generate!) + I don't see the test being run in the CI! 🤗

@ArthurZucker
Copy link
Collaborator

Could you just make sure it's run

@ydshieh2
Copy link

ydshieh2 commented Nov 5, 2024

We need to remove @require_torch_gpu too for def test_generate_compile

tests/generation/test_utils.py Outdated Show resolved Hide resolved
@ydshieh
Copy link
Collaborator

ydshieh commented Jan 23, 2025

Before merge, feel free to ping me for a check for (if there is any) flakyness :-) or anything you think I can double check again.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks can ignore my comments and merge 🤗

Comment on lines 359 to 360
elif isinstance(past_key_values, HybridCache):
target_length = past_key_values.get_max_cache_shape()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible for the HybridCache to inherit from Static cache?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might just need an extra class that says CompileCompatible , someone wanted is_static attr˜!

@gante
Copy link
Member Author

gante commented Jan 27, 2025

(sorry, the PR is not ready yet, a few cases are still failing 👀 I didn't mean to request a review)

@gante
Copy link
Member Author

gante commented Jan 28, 2025

Now it's working on all models, including encoder-decoder + cache 🤗

It's not too heavy on our CI, it should add ~2 mins if all models are run. And it should prevent us from many headaches! As we can see in diff, we had compilation enabled for a bunch of models that don't support it.

Screenshot 2025-01-28 at 12 48 05

@gante gante merged commit ece8c42 into huggingface:main Jan 28, 2025
25 checks passed
@gante gante deleted the generate_forward_compile_fix branch January 28, 2025 14:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants