We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
After merging PR #3296, I see the following error:
ValueError: Cannot view a tensor with shape torch.Size([s6, s2 + 4096, 24, 128]) and strides (3072*s2 + 12582912, 128, 128*s2 + 524288, 1) as a tensor with shape (s1, (s6*(s2 + 4096)//s1), 3072)! While executing %view_52 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%transpose_10, [%sym_size_int_63, -1, 3072]), kwargs = {}) Original traceback: File "/work/TensorRT/examples/dynamo/run_2.py", line 48, in forward return self.module.forward( File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/diffusers/models/transformers/transformer_flux.py", line 438, in forward hidden_states = block( File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/diffusers/models/transformers/transformer_flux.py", line 119, in forward attn_output = self.attn( File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 490, in forward return self.processor(
Here's the full script:
# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# NOTE(review): the original whitespace of this script was lost (it was
# collapsed onto a few long lines). Line breaks and indentation below are
# reconstructed from the syntax; where the nesting is ambiguous it is called
# out explicitly -- confirm against the original script.
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import FluxPipeline, FluxTransformer2DModel
from utils import export_llm, generate
from torch.export import Dim
from typing import Optional, Dict, Any
import logging

# Module logger that prints DEBUG and above through a stream handler.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
logger.addHandler(handler)

import time
from contextlib import contextmanager


@contextmanager
def timer(logger, name:str):
    """Log the start, the end and the elapsed wall-clock time of a ``with`` body.

    Args:
        logger: Logger whose ``info`` method receives the messages.
        name: Label inserted into each log message.

    Note:
        The ``yield`` is not wrapped in ``try``/``finally``, so when the body
        raises, the "End" and "elapsed time" messages are not logged.
    """
    logger.info(f"{name} section Start...")
    start = time.time()
    yield
    end = time.time()
    logger.info(f"{name} section End...")
    logger.info(f"{name} section elapsed time: {end - start} seconds")


class MyModule(torch.nn.Module):
    """Thin wrapper whose ``forward`` passes a fixed set of six inputs on.

    ``forward`` accepts ``guidance``, ``joint_attention_kwargs``,
    ``return_dict`` and arbitrary ``**kwargs`` but does not pass them to the
    wrapped module; only the first six arguments are forwarded, positionally.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None, pooled_projections: torch.Tensor = None, timestep: torch.LongTensor = None, img_ids: torch.Tensor = None, txt_ids: torch.Tensor = None, guidance: torch.Tensor = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = False, **kwargs):
        # Only these six inputs reach the wrapped module; the remaining
        # parameters are accepted and ignored.
        return self.module.forward(
            hidden_states,
            encoder_hidden_states,
            pooled_projections,
            timestep,
            img_ids,
            txt_ids,
        )


def wrap_pipeline_transformer_call(instance, prompt, max_sequence_length):
    """Run ``instance`` once and capture the arguments of its transformer call.

    ``instance.transformer.forward`` is patched with a mock that wraps the
    real method, ``instance`` is called with a single inference step, and the
    positional/keyword arguments of the first recorded call are returned.

    Returns:
        ``(args, kwargs)`` of the first recorded call, or ``(None, None)``
        when no call was recorded.

    Note:
        This helper is not called anywhere in the visible script.
    """
    from unittest.mock import patch
    # Assume `instance` is your class instance containing the `__call__` method
    # Use patch.object to mock the __call__ method of self.transformer
    with patch.object(instance.transformer, 'forward', wraps=instance.transformer.forward) as mock_transformer_call:
        # one step is enough for intercept the inputs
        # (`image` is assigned but never read afterwards.)
        image =instance(
            prompt,
            guidance_scale=0.0,
            num_inference_steps=1,
            max_sequence_length=max_sequence_length,
            generator=torch.Generator("cpu").manual_seed(0)
        ).images[0]
        # Access the call arguments of the first (or specific) call
        if mock_transformer_call.call_args_list:
            args, kwargs = mock_transformer_call.call_args_list[0]
            # Store the inputs in a tuple
            # (`intercepted_inputs` is assigned but never read; the same
            # pair is rebuilt in the return statement.)
            intercepted_inputs = (args, kwargs)
            # print("Intercepted args:", args)
            # print("Intercepted kwargs:", kwargs)
            return (args, kwargs)
        else:
            print("No calls were made to self.transformer.__call__")
            return (None, None)


if __name__ == "__main__":
    # config
    dryrun = False
    # parameter setting
    batch_size = 2
    max_seq_len = 256
    prompt = ["A cat holding a sign that says hello world" for _ in range(batch_size)]
    # `cuda_device` and `device` hold the same string; `device` is used for
    # moving tensors/pipeline, `cuda_device` for memory queries and compile.
    cuda_device = "cuda:0"
    device="cuda:0"
    # NOTE(review): everything below is assumed to be nested under
    # `torch.no_grad()`; the extent of this block cannot be recovered from
    # the collapsed text -- confirm.
    with torch.no_grad():
        pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16)
        pipe.to(device)
        # Positional example inputs, in the order of MyModule.forward:
        # hidden_states, encoder_hidden_states, pooled_projections,
        # timestep, img_ids, txt_ids.
        example_inputs = (
            torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(device),
            torch.randn((batch_size, 256, 4096), dtype=torch.float16).to(device),
            torch.randn((batch_size, 768), dtype=torch.float16).to(device),
            # Literal two-element tensor: matches batch_size = 2 above but
            # is not derived from `batch_size`.
            torch.tensor([1., 1.], dtype=torch.float16).to(device),
            torch.randn((batch_size, 4096, 3), dtype=torch.float16).to(device),
            torch.randn((batch_size, 256, 3), dtype=torch.float16).to(device),
        )
        BATCH = Dim("batch", min=1, max=batch_size)
        SEQ_LEN = Dim("seq_len", min=1, max=max_seq_len)
        # One spec per example input: dim 0 is dynamic (BATCH) for all six;
        # dim 1 is dynamic (SEQ_LEN) for the 2nd and 6th inputs only.
        dynamic_shapes = (
            {0 : BATCH},
            {0 : BATCH, 1 : SEQ_LEN},
            {0 : BATCH},
            {0 : BATCH},
            {0 : BATCH},
            {0 : BATCH, 1 : SEQ_LEN},
        )
        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"1 Free mem: {free}, Total mem: {total}")
        # breakpoint()
        with timer(logger=logger, name="ep_gen"):
            model = MyModule(pipe.transformer).eval().half()#.to(device)
            logger.info("Directly use _export because torch.export.export doesn't work")
            # This API is used to express the constraint violation guards as asserts in the graph.
            from torch.export._trace import _export
            ep = _export(
                model,
                args=example_inputs,
                dynamic_shapes=dynamic_shapes,
                strict=False,
                allow_complex_guards_as_runtime_asserts=True,
            )
        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"2 Free mem: {free}, Total mem: {total}")
        # breakpoint()
        logger.info(f"Generating TRT engine now, dryrun={dryrun}...")
        # print("Generating TRT engine now...")
        #TODO: if some non-tensor input, do we still need to provide them.
        with timer(logger, "trt_gen"):
            with torch_tensorrt.logging.debug():
                # `trt_start` / `trt_end` are recorded but never read; the
                # surrounding `timer` already logs the elapsed time.
                trt_start = time.time()
                trt_model = torch_tensorrt.dynamo.compile(
                    ep,
                    inputs=list(example_inputs),
                    enabled_precisions={torch.float32},
                    truncate_double=True,
                    device=torch.device(cuda_device),
                    disable_tf32=True,
                    use_explicit_typing=True,
                    dryrun=dryrun,
                    debug=True,
                    use_fp32_acc=True,
                )
                trt_end = time.time()
        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"3 Free mem: {free}, Total mem: {total}")
        # Interactive stop: the script pauses in the debugger here.
        breakpoint()
        # Drop the references to the pipeline, exported program and wrapper
        # before measuring memory again and saving.
        del pipe
        del ep
        del model
        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"4 Free mem: {free}, Total mem: {total}")
        breakpoint()
        import gc
        gc.collect()
        torch.cuda.empty_cache()
        # The loop variable `input` shadows the builtin of the same name
        # inside this comprehension.
        example_inputs_cuda = [input.cuda() for input in example_inputs]
        with timer(logger, "trt_save"):
            try:
                breakpoint()
                trt_ep = torch.export.export(trt_model, args=example_inputs_cuda, dynamic_shapes=dynamic_shapes, strict=False)
                torch.export.save(trt_ep, "trt.ep")
            except Exception as e:
                import traceback
                # Capture the full traceback
                tb = traceback.format_exc()
                logger.warning("An error occurred. Here's the traceback:")
                # print(tb)
                logger.warning(tb)
                breakpoint()
                # NOTE(review): assumed to be a fallback save inside the
                # `except` branch; the original nesting of this call is
                # ambiguous in the collapsed text -- confirm.
                torch_tensorrt.save(trt_model, "trt.ep")
Build information about Torch-TensorRT can be found by turning on debug messages
conda
pip
libtorch
The text was updated successfully, but these errors were encountered:
The full log, after modifying the Flux model source code in the diffusers library to use just a single transformer layer, is here: full_log.txt
Sorry, something went wrong.
@HolyWu Thanks for raising the PR again. Can you let me know which lines in the code fixed this issue?
The view_to_reshape lowering pass was replaced with a decomposition, because a lowering pass runs too late: the error has already been triggered by the time it would run.
Successfully merging a pull request may close this issue.
Bug Description
After merging PR #3296, I see the following error:
To Reproduce
Here's the full script:
Expected behavior
Environment
conda
,pip
,libtorch
, source):Additional context
The text was updated successfully, but these errors were encountered: