🐛 [Bug] Decomposing attention leads to shape errors (due to view op) in FLUX model #3333

Open
peri044 opened this issue Dec 19, 2024 · 3 comments · May be fixed by #3336
Labels
bug Something isn't working

Comments

peri044 (Collaborator) commented Dec 19, 2024

Bug Description

After merging PR #3296, I see the following error:

```
ValueError: Cannot view a tensor with shape torch.Size([s6, s2 + 4096, 24, 128]) and strides (3072*s2 + 12582912, 128, 128*s2 + 524288, 1) as a tensor with shape (s1, (s6*(s2 + 4096)//s1), 3072)!

While executing %view_52 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%transpose_10, [%sym_size_int_63, -1, 3072]), kwargs = {})
Original traceback:
  File "/work/TensorRT/examples/dynamo/run_2.py", line 48, in forward
    return self.module.forward(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/diffusers/models/transformers/transformer_flux.py", line 438, in forward
    hidden_states = block(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/diffusers/models/transformers/transformer_flux.py", line 119, in forward
    attn_output = self.attn(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 490, in forward
    return self.processor(
```
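The strides in this error are exactly those of a contiguous `[batch, 24, S, 128]` attention output (with `S = s2 + 4096`) after `transpose(1, 2)`. Merging the last two dims into 3072 with `view` requires them to be adjacent in memory, which they no longer are. The pure-Python sketch below checks this on concrete numbers. It is a simplified function modeled on the stride check PyTorch performs for `view`, not PyTorch's code. It assumes nonzero sizes, no `-1` in the target shape, and arbitrarily chosen values `s1 = s6 = 2`, `s2 = 256`.

```python
from typing import List, Optional, Sequence


def contiguous_strides(shape: Sequence[int]) -> List[int]:
    """Row-major (contiguous) strides for `shape`, counted in elements."""
    strides = [1] * len(shape)
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    return strides


def view_strides(shape: Sequence[int], strides: Sequence[int],
                 new_shape: Sequence[int]) -> Optional[List[int]]:
    """Strides a view of (shape, strides) as new_shape would use, or None if
    no view exists. Simplified sketch modeled on PyTorch's view stride check:
    assumes concrete, nonzero sizes and no -1 entry in new_shape."""
    new_strides = [0] * len(new_shape)
    view_d = len(new_shape) - 1
    chunk_base_stride = strides[-1]
    tensor_numel = 1
    view_numel = 1
    for tensor_d in range(len(shape) - 1, -1, -1):
        tensor_numel *= shape[tensor_d]
        # A chunk of dims ends where the next-outer dim is not laid out
        # directly after the current chunk in memory.
        if tensor_d == 0 or (shape[tensor_d - 1] != 1 and
                             strides[tensor_d - 1] != tensor_numel * chunk_base_stride):
            # Assign new dims to this chunk until they cover the same numel.
            while view_d >= 0 and (view_numel < tensor_numel or new_shape[view_d] == 1):
                new_strides[view_d] = view_numel * chunk_base_stride
                view_numel *= new_shape[view_d]
                view_d -= 1
            if view_numel != tensor_numel:
                return None  # a new dim would straddle two non-adjacent chunks
            if tensor_d > 0:
                chunk_base_stride = strides[tensor_d - 1]
                tensor_numel = 1
                view_numel = 1
    return new_strides if view_d == -1 else None


# Assumed example values: batch s1 = s6 = 2 and s2 = 256, so S = s2 + 4096 = 4352.
B, S, HEADS, HEAD_DIM = 2, 256 + 4096, 24, 128

# Contiguous attention output [B, heads, S, head_dim]; transpose(1, 2) only swaps
# the sizes and strides of dims 1 and 2, it does not move any data.
pre_shape = [B, HEADS, S, HEAD_DIM]
pre_strides = contiguous_strides(pre_shape)
shape = [B, S, HEADS, HEAD_DIM]
strides = [pre_strides[0], pre_strides[2], pre_strides[1], pre_strides[3]]
target = [B, S, HEADS * HEAD_DIM]

print(strides)                               # [13369344, 128, 557056, 1]
print(view_strides(shape, strides, target))  # None: heads and head_dim are not adjacent
print(view_strides(shape, contiguous_strides(shape), target))  # [13369344, 3072, 1]
```

With `s2 = 256`, the first line reproduces the symbolic strides from the message (`3072*s2 + 12582912 = 13369344`, `128*s2 + 524288 = 557056`). Once the data is copied into a contiguous layout, which is what `reshape` does when a plain view is impossible, the same target shape becomes a legal view.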

To Reproduce

Here's the full script:

```python
# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import gc
import logging
import time
import traceback
from contextlib import contextmanager
from typing import Any, Dict, Optional
from unittest.mock import patch

import torch
import torch_tensorrt
from diffusers import FluxPipeline
from torch.export import Dim
from torch.export._trace import _export

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
logger.addHandler(handler)


@contextmanager
def timer(logger, name: str):
    """Log the wall-clock time spent inside the `with` block."""
    logger.info(f"{name} section Start...")
    start = time.time()
    yield
    end = time.time()
    logger.info(f"{name} section End...")
    logger.info(f"{name} section elapsed time: {end - start} seconds")


class MyModule(torch.nn.Module):
    """Thin wrapper that forwards only the positional inputs being exported."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self,
                hidden_states: torch.Tensor,
                encoder_hidden_states: torch.Tensor = None,
                pooled_projections: torch.Tensor = None,
                timestep: torch.LongTensor = None,
                img_ids: torch.Tensor = None,
                txt_ids: torch.Tensor = None,
                guidance: torch.Tensor = None,
                joint_attention_kwargs: Optional[Dict[str, Any]] = None,
                return_dict: bool = False, **kwargs):
        return self.module.forward(
            hidden_states,
            encoder_hidden_states,
            pooled_projections,
            timestep,
            img_ids,
            txt_ids,
        )


def wrap_pipeline_transformer_call(instance, prompt, max_sequence_length):
    """Run one pipeline step and capture the arguments passed to instance.transformer.forward."""
    # wraps= keeps the real forward running while the mock records each call.
    with patch.object(instance.transformer, 'forward', wraps=instance.transformer.forward) as mock_transformer_call:
        # One inference step is enough to intercept the inputs.
        image = instance(
            prompt,
            guidance_scale=0.0,
            num_inference_steps=1,
            max_sequence_length=max_sequence_length,
            generator=torch.Generator("cpu").manual_seed(0)
        ).images[0]

        # Return the (args, kwargs) of the first recorded call, if any.
        if mock_transformer_call.call_args_list:
            args, kwargs = mock_transformer_call.call_args_list[0]
            return (args, kwargs)
        else:
            print("No calls were made to self.transformer.forward")
            return (None, None)


if __name__ == "__main__":

    # config
    dryrun = False

    # parameter setting
    batch_size = 2
    max_seq_len = 256
    prompt = ["A cat holding a sign that says hello world" for _ in range(batch_size)]
    cuda_device = "cuda:0"
    device = "cuda:0"
    with torch.no_grad():
        pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell",
                                            torch_dtype=torch.float16)
        pipe.to(device)

        # hidden_states, encoder_hidden_states, pooled_projections, timestep, img_ids, txt_ids
        example_inputs = (torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(device),
                          torch.randn((batch_size, 256, 4096), dtype=torch.float16).to(device),
                          torch.randn((batch_size, 768), dtype=torch.float16).to(device),
                          torch.ones(batch_size, dtype=torch.float16).to(device),
                          torch.randn((batch_size, 4096, 3), dtype=torch.float16).to(device),
                          torch.randn((batch_size, 256, 3), dtype=torch.float16).to(device),
                          )
        BATCH = Dim("batch", min=1, max=batch_size)
        SEQ_LEN = Dim("seq_len", min=1, max=max_seq_len)
        dynamic_shapes = ({0: BATCH},
                          {0: BATCH, 1: SEQ_LEN},
                          {0: BATCH},
                          {0: BATCH},
                          {0: BATCH},
                          {0: BATCH, 1: SEQ_LEN},
                          )
        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"1 Free mem: {free}, Total mem: {total}")
        with timer(logger=logger, name="ep_gen"):
            model = MyModule(pipe.transformer).eval().half()
            logger.info("Directly use _export because torch.export.export doesn't work")
            # This API is used to express the constraint violation guards as asserts in the graph.
            ep = _export(
                model,
                args=example_inputs,
                dynamic_shapes=dynamic_shapes,
                strict=False,
                allow_complex_guards_as_runtime_asserts=True,
            )
        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"2 Free mem: {free}, Total mem: {total}")
        logger.info(f"Generating TRT engine now, dryrun={dryrun}...")
        # TODO: if some inputs are non-tensors, do we still need to provide them?
        with timer(logger, "trt_gen"):
            with torch_tensorrt.logging.debug():
                trt_model = torch_tensorrt.dynamo.compile(
                    ep,
                    inputs=list(example_inputs),
                    enabled_precisions={torch.float32},
                    truncate_double=True,
                    device=torch.device(cuda_device),
                    disable_tf32=True,
                    use_explicit_typing=True,
                    dryrun=dryrun,
                    debug=True,
                    use_fp32_acc=True,
                )
        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"3 Free mem: {free}, Total mem: {total}")
        del pipe
        del ep
        del model

        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"4 Free mem: {free}, Total mem: {total}")
        gc.collect()
        torch.cuda.empty_cache()
        # torch.export.export expects the example positional inputs as a tuple.
        example_inputs_cuda = tuple(t.cuda() for t in example_inputs)
        with timer(logger, "trt_save"):
            try:
                trt_ep = torch.export.export(trt_model, args=example_inputs_cuda,
                                             dynamic_shapes=dynamic_shapes, strict=False)
                torch.export.save(trt_ep, "trt.ep")
            except Exception:
                # Log the full traceback, then fall back to Torch-TensorRT's own serializer.
                logger.warning("An error occurred. Here's the traceback:")
                logger.warning(traceback.format_exc())
                torch_tensorrt.save(trt_model, "trt.ep")
```
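The helper `wrap_pipeline_transformer_call` captures the transformer's real inputs with `unittest.mock.patch.object(..., wraps=...)`. The original `forward` still runs, and the mock records every call in `call_args_list`. The standalone sketch below shows the same trick without a GPU. `ToyTransformer` and `ToyPipeline` are hypothetical stand-ins for `pipe.transformer` and `FluxPipeline`, and `intercept_first_call` is an illustrative name, not part of any library.

```python
from unittest.mock import patch


class ToyTransformer:
    """Hypothetical stand-in for pipe.transformer."""

    def forward(self, hidden_states, timestep=None):
        return [h * 2 for h in hidden_states]


class ToyPipeline:
    """Hypothetical stand-in for FluxPipeline: __call__ invokes transformer.forward."""

    def __init__(self):
        self.transformer = ToyTransformer()

    def __call__(self, prompt, num_inference_steps=1):
        latents = [len(prompt)]
        for step in range(num_inference_steps):
            latents = self.transformer.forward(latents, timestep=step)
        return latents


def intercept_first_call(pipe, prompt):
    # wraps= keeps the real forward running while the mock records each call;
    # the patch is undone when the `with` block exits.
    with patch.object(pipe.transformer, "forward", wraps=pipe.transformer.forward) as spy:
        output = pipe(prompt, num_inference_steps=1)
    if not spy.call_args_list:
        return None, None, output
    args, kwargs = spy.call_args_list[0]
    return args, kwargs, output


args, kwargs, output = intercept_first_call(ToyPipeline(), "a cat")
print(args, kwargs, output)  # ([5],) {'timestep': 0} [10]
```

Using `wraps=` rather than a plain mock matters here: a plain mock would return a `MagicMock` instead of real outputs, and the rest of the pipeline step would not run normally.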

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

peri044 added the bug (Something isn't working) label on Dec 19, 2024
peri044 (Collaborator, Author) commented Dec 19, 2024

The full log, captured after modifying the FLUX model source code in the diffusers library to use just a single transformer layer, is here:
full_log.txt

HolyWu linked a pull request on Dec 22, 2024 that will close this issue
peri044 (Collaborator, Author) commented Jan 2, 2025

@HolyWu Thanks for raising the PR again. Can you let me know which lines in the code fixed this issue?

HolyWu (Contributor) commented Jan 2, 2025

The view_to_reshape lowering pass is replaced with a decomposition, because the lowering pass runs too late: the error is triggered before the pass ever gets a chance to run.
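The argument here is about ordering. The `ValueError` is raised while the graph is being decomposed and its shapes propagated, so a pass that rewrites `view` into `reshape` afterwards never gets to run. The toy pipeline below illustrates only that ordering, in pure Python. Everything in it is hypothetical and heavily simplified: the op table, `trace`, `view_to_reshape_pass`, `compile_graph`, and the rule that only contiguous tensors can be viewed. None of it is Torch-TensorRT's actual API or implementation.

```python
from typing import Callable, Dict, List, Tuple

Shape = Tuple[int, ...]
Meta = Tuple[Shape, bool]  # (shape, is_contiguous): all this toy tracks
Graph = List[Tuple[str, tuple]]  # (op name, extra args) applied in order


def _numel(shape: Shape) -> int:
    n = 1
    for s in shape:
        n *= s
    return n


def transpose(meta: Meta, d0: int, d1: int) -> Meta:
    shape = list(meta[0])
    shape[d0], shape[d1] = shape[d1], shape[d0]
    return tuple(shape), False  # simplification: treat every transpose as non-contiguous


def view(meta: Meta, new_shape: Shape) -> Meta:
    shape, contiguous = meta
    # Simplified rule: only contiguous inputs can be viewed (real PyTorch is less strict).
    if not contiguous or _numel(shape) != _numel(new_shape):
        raise ValueError(f"Cannot view a tensor with shape {shape} as {new_shape}")
    return tuple(new_shape), True


def reshape(meta: Meta, new_shape: Shape) -> Meta:
    shape, _ = meta
    # Model reshape as "copy to a contiguous layout, then view".
    return view((shape, True), new_shape)


OPS: Dict[str, Callable[..., Meta]] = {"transpose": transpose, "view": view, "reshape": reshape}


def trace(graph: Graph, meta: Meta, decompositions: Dict[str, str]) -> Meta:
    """Stage 1: apply decompositions and propagate shapes op by op."""
    for op, args in graph:
        meta = OPS[decompositions.get(op, op)](meta, *args)
    return meta


def view_to_reshape_pass(graph: Graph) -> Graph:
    """Stage 2: a lowering pass; it only ever sees a graph that traced cleanly."""
    return [("reshape" if op == "view" else op, args) for op, args in graph]


def compile_graph(graph: Graph, meta: Meta, decompositions: Dict[str, str]) -> Tuple[Meta, Graph]:
    out_meta = trace(graph, meta, decompositions)  # a bad view raises here...
    lowered = view_to_reshape_pass(graph)          # ...so this rewrite comes too late
    return out_meta, lowered


attn: Graph = [("transpose", (1, 2)), ("view", ((2, 4352, 3072),))]
inp: Meta = ((2, 24, 4352, 128), True)

try:
    compile_graph(attn, inp, decompositions={})
except ValueError as err:
    print("lowering pass only ->", err)

out_meta, lowered = compile_graph(attn, inp, decompositions={"view": "reshape"})
print("view decomposed to reshape ->", out_meta)  # ((2, 4352, 3072), True)
```

The only thing that changes between the two runs is where the view-to-reshape rewrite sits. As a decomposition, it changes what gets traced. As a lowering pass, it can only transform a graph that already traced successfully.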
