Add EasyAnimateV5.1 text-to-video, image-to-video, control-to-video generation model #10626

Open

bubbliiiing wants to merge 5 commits into main

Conversation

@bubbliiiing commented on Jan 22, 2025

What does this PR do?

This PR converts the EasyAnimateV5.1 model into a diffusers-supported inference model, including three complete pipelines (text-to-video, image-to-video, and control-to-video) and their corresponding modules.

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@a-r-r-o-w (Member) left a comment

Thank you for the PR @bubbliiiing! This is in great shape and already mostly in the implementation style used in diffusers 🤗

I've left some comments from a quick look through the PR. Happy to make any of the required changes to bring the PR to completion.

.gitignore (review thread, resolved)
src/diffusers/models/attention_processor.py (review thread, resolved)
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
attn2: Attention = None,

@a-r-r-o-w (Member):

This seems similar to Flux/SD3/HunyuanVideo's joint-attention processors that concatenate the visual and text tokens. Let's do it the same way as in those implementations.
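
For reference, here is a minimal sketch of the joint-attention pattern being suggested (illustrative only, assuming the add_q_proj/add_k_proj/add_v_proj naming discussed below; this is not the exact diffusers processor):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class JointAttentionSketch(nn.Module):
    """Sketch: project text tokens with separate add_*_proj layers and
    concatenate them with the visual tokens before a single attention call."""

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Projections for the visual stream
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        # Separate projections for the text stream (hypothetical names that
        # mirror the diffusers-format keys mentioned in this review)
        self.add_q_proj = nn.Linear(dim, dim)
        self.add_k_proj = nn.Linear(dim, dim)
        self.add_v_proj = nn.Linear(dim, dim)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
        batch_size, visual_len = hidden_states.shape[:2]
        # Project each stream, then concatenate along the sequence axis
        q = torch.cat([self.to_q(hidden_states), self.add_q_proj(encoder_hidden_states)], dim=1)
        k = torch.cat([self.to_k(hidden_states), self.add_k_proj(encoder_hidden_states)], dim=1)
        v = torch.cat([self.to_v(hidden_states), self.add_v_proj(encoder_hidden_states)], dim=1)
        # (batch, seq, dim) -> (batch, heads, seq, head_dim)
        q, k, v = (t.unflatten(2, (self.num_heads, self.head_dim)).transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).flatten(2)
        # Split the joint sequence back into visual and text tokens
        visual, text = out[:, :visual_len], out[:, visual_len:]
        return self.to_out(visual), text
```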

@bubbliiiing (Author):

Does this mean using add_q, add_k, add_v instead of using attn2?

@a-r-r-o-w (Member):

There are two things we could do here:

  • Either convert the state dict of the original-format models (that you currently have on the HuggingFace Hub) and update them to diffusers-format, which would rename attn2.to_q -> add_q_proj, attn2.to_k -> add_k_proj, and attn2.to_v -> add_v_proj (see the sketch after this comment).
  • Or create a custom attention class similar to Attention and MochiAttention, in which you are free to use layer naming of your choice (so basically keeping the same to_q, to_k, and to_v).

The first approach is more closely aligned with diffusers code style but would require you to update multiple checkpoints. However, we are transitioning to a single-file modeling format, so if you choose to go with the second approach for convenience, that works for us as well. Essentially, irrespective of the design you choose, we need to make sure:

  • When the forward of a layer is called, it only takes tensors as input and produces tensors as output.
  • Taking intermediate layers as input to forward, or making out-of-order calls to other layers, is not compatible with the design of several current and upcoming features.

cc @DN6 here in case you have thoughts about the single file format and model-specific Attention classes
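
For illustration, a minimal sketch of the first approach (the surrounding key prefixes in a real checkpoint are an assumption; only the attn2.to_{q,k,v} -> add_{q,k,v}_proj mapping comes from this discussion):

```python
def remap_attention_keys(state_dict: dict) -> dict:
    """Sketch: rename original-format attention keys to diffusers-format
    names, e.g. 'blocks.0.attn2.to_q.weight' -> 'blocks.0.add_q_proj.weight'
    (the 'blocks.0.' prefix is hypothetical)."""
    renames = {
        "attn2.to_q": "add_q_proj",
        "attn2.to_k": "add_k_proj",
        "attn2.to_v": "add_v_proj",
    }
    remapped = {}
    for key, value in state_dict.items():
        for old, new in renames.items():
            if old in key:
                key = key.replace(old, new)
        remapped[key] = value
    return remapped
```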

@bubbliiiing (Author):

I moved attn2 to the processor's init; does this meet the requirements?

src/diffusers/models/autoencoders/autoencoder_kl_magvit.py (3 review threads, resolved)
for name, module in self.named_children():
_set_3dgroupnorm_for_submodule(name, module)

def single_forward(self, x: torch.Tensor) -> torch.Tensor:

@a-r-r-o-w (Member):

This is very different from the diffusers-style implementation of encoders/decoders. Could we follow the style used in the more recent VAE implementations?

@bubbliiiing (Author):

Sorry, does this mean that I cannot use functions like set_padding_one_frame?

@bubbliiiing (Author):

Do I need to use this conv_cache in autoencoder_kl_mochi?

@a-r-r-o-w (Member):

In the model implementations, we usually try to keep only (at least in the latest model integrations):

  • Submodel initializations
  • Forward method

So, unless a helper function like set_padding_one_frame is used in multiple locations, I would suggest substituting its code directly into the forward implementation. If a helper function is required, let's make it a private function by prefixing its name with an underscore.

The conv_cache saves a few computations when running the VAE encode/decode process from repeated frames that are used as padding. As such, it is not required to implement it if it is not needed for framewise encoding and decoding.
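
For intuition, a minimal sketch of the conv_cache pattern (loosely modeled on the causal 3D convolutions in the newer diffusers VAEs; class and argument names here are illustrative):

```python
from typing import Optional

import torch
import torch.nn as nn

class CausalConv3dSketch(nn.Module):
    """Sketch: a time-causal 3D convolution that caches the trailing frames
    of each chunk so framewise encode/decode can pad the next chunk with
    real context instead of recomputing repeated padding frames."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()
        self.time_pad = kernel_size - 1
        # Pad spatially inside the conv; temporal padding is handled manually
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=(0, 1, 1))

    def forward(self, x: torch.Tensor, conv_cache: Optional[torch.Tensor] = None):
        if conv_cache is None:
            # First chunk: causal padding by repeating the first frame
            pad = x[:, :, :1].repeat(1, 1, self.time_pad, 1, 1)
        else:
            # Later chunks: reuse the cached trailing frames of the previous chunk
            pad = conv_cache
        x = torch.cat([pad, x], dim=2)
        new_cache = x[:, :, -self.time_pad :].clone()  # cache for the next chunk
        return self.conv(x), new_cache
```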

src/diffusers/models/autoencoders/autoencoder_kl_magvit.py (review thread, resolved)
src/diffusers/models/downsampling.py (review thread, resolved)
src/diffusers/models/normalization.py (review thread, resolved)
@yiyixuxu added the roadmap (Add to current release roadmap) label on Jan 22, 2025

@bubbliiiing (Author):

Sorry for not standardizing some parts; I will make the necessary modifications. Also, do I need to add test files in tests/pipelines and documentation in docs/source/en/api/pipelines?

@a-r-r-o-w (Member):

Yes, we will need tests for all three pipelines, model tests in tests/models, and the documentation pages. Of course, I will try to help you with everything :)

Also, congratulations on the release! I tried out the original repository example and the model is very good! 🎉

Labels: roadmap (Add to current release roadmap)
Projects: Status: In Progress
3 participants