-
Notifications
You must be signed in to change notification settings - Fork 27.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Test: generate with torch.compile(model.forward)
as a fast test
#34544
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Love this!
Q: Is it really fast ...?
Remark: I feel get_max_cache_length
is a better name than get_max_cache_shape
but OK I know not great to change name all the time.
@ydshieh yes :D |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't mind, tho I don't think our priority should be this (full compile vs compile forward in generate!) + I don't see the test being run in the CI! 🤗
Could you just make sure it's run |
We need to remove |
Before merge, feel free to ping me for a check for (if there is any) flakyness :-) or anything you think I can double check again. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks can ignore my comments and merge 🤗
elif isinstance(past_key_values, HybridCache): | ||
target_length = past_key_values.get_max_cache_shape() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is it possible for the HybridCache to inherit from Static cache?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might just need an extra class that says CompileCompatible
, someone wanted is_static
attr˜!
(sorry, the PR is not ready yet, a few cases are still failing 👀 I didn't mean to request a review) |
What does this PR do?
Follow-up to #34464
This PR:
test_generate_compile_model_forward
to a fast test. This means we will check generate withtorch.compile(model.forward)
at each commit on ALL models that supportStaticCache
💛test_generate_compile_model_forward
whenever possible_supports_static_cache = False #Reason
when the model doesn't supporttorch.compile(model.forward)
✅
py.test tests/models/ -k test_generate_compile
is all green, takes ~2 mins to run on all models on my machine