Unsupported: hasattr SkipFunctionVariable when I compile the Mixtral model with multiple GPUs #35623

Open
4 tasks
zyxiyy opened this issue Jan 11, 2025 · 2 comments

zyxiyy commented Jan 11, 2025

System Info

none

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import StaticCache

NUM_TOKENS_TO_GENERATE = 40
torch_device = "cuda"


def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
    logits = model(
        cur_token,
        position_ids=input_pos,
        cache_position=cache_position,
        past_key_values=past_key_values,
        return_dict=False,
        use_cache=True,
    )[0]
    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    return new_token


batch_size, seq_length = inputs["input_ids"].shape
with torch.no_grad():
    past_key_values = StaticCache(
        config=model.config,
        max_batch_size=1,
        max_cache_len=4096,
        device=torch_device,
        dtype=model.dtype,
        layer_device_map=layer_device_map,
    )
    cache_position = torch.arange(seq_length, device=torch_device)
    generated_ids = torch.zeros(
        batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
    )
    generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)

    logits = model(
        **inputs, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
    )[0]
    next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    generated_ids[:, seq_length] = next_token[:, 0]

    decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
    cache_position = torch.tensor([seq_length + 1], device=torch_device)
    input_position = cache_position.clone
    for _ in range(1, NUM_TOKENS_TO_GENERATE):
        with sdpa_kernel(SDPBackend.MATH):
            next_token = decode_one_tokens(model, next_token.clone(), input_position, cache_position, past_key_values)
            generated_ids[:, cache_position] = next_token.int()
        cache_position += 1
        input_position += 1

text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

Expected behavior

Unsupported: hasattr SkipFunctionVariable to

from user code:
  File "/tmp/ipykernel_1957076/1822748636.py", line 7, in decode_one_tokens
    logits = model(
  File "/home/bcds/.conda/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bcds/.conda/envs/llm/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
    args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
  File "/home/bcds/.conda/envs/llm/lib/python3.9/site-packages/accelerate/hooks.py", line 364, in pre_forward
    return send_to_device(args, self.execution_device), send_to_device(
  File "/home/bcds/.conda/envs/llm/lib/python3.9/site-packages/accelerate/utils/operations.py", line 184, in send_to_device
    {
  File "/home/bcds/.conda/envs/llm/lib/python3.9/site-packages/accelerate/utils/operations.py", line 185, in
    k: t if k in skip_keys else send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys)
  File "/home/bcds/.conda/envs/llm/lib/python3.9/site-packages/accelerate/utils/operations.py", line 149, in send_to_device
    if is_torch_tensor(tensor) or hasattr(tensor, "to"):

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
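
One possible reading of this trace, which nobody in the thread confirms: the reproduction assigns `cache_position.clone` to `input_position` without call parentheses, so the value passed as `position_ids` would be a bound method, not a tensor, and that is the kind of object `hasattr(tensor, "to")` would then be probing. The sketch below illustrates only the plain-Python difference between the two spellings. `FakeTensor` and `looks_movable` are hypothetical stand-ins invented for this illustration, not part of torch or accelerate.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor; it only mimics clone() and to()."""

    def __init__(self, values):
        self.values = list(values)

    def clone(self):
        # Return an independent copy, as a real tensor's clone() would.
        return FakeTensor(self.values)

    def to(self, device):
        # A real tensor would move its data; the stand-in just returns a copy.
        return self.clone()


def looks_movable(obj):
    # Mirrors the check quoted in the traceback: hasattr(tensor, "to").
    return hasattr(obj, "to")


cache_position = FakeTensor([5])

not_called = cache_position.clone   # bound method object (no parentheses)
called = cache_position.clone()     # an actual copy of the data

print(type(not_called).__name__, looks_movable(not_called))
print(type(called).__name__, looks_movable(called))
```

If this is what happened, writing `cache_position.clone()` would give `input_position` a real tensor, and the later in-place increment would have something to act on. Whether the compiled step then works across multiple GPUs is not something this thread establishes.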

zyxiyy added the bug label Jan 11, 2025
Rocketknight1 (Member) commented

It seems like you might be trying to compile the entire generate loop, which I don't think we support yet (cc @gante). Support should be coming soon, but in the meantime you might have to use the built-in generate() methods, or run your generation loop eagerly!
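
The first fallback mentioned above could look roughly like the following. This is a sketch under assumptions, not a confirmed fix: the helper name `generate_eagerly` is hypothetical, and `model`, `tokenizer`, and `inputs` are assumed to be the same objects used in the reproduction. `max_new_tokens` and `do_sample` are regular `generate()` arguments; `do_sample=False` is chosen here to mirror the greedy `argmax` in the original loop.

```python
def generate_eagerly(model, tokenizer, inputs, max_new_tokens=40):
    """Hypothetical helper: decode with the built-in generate() in eager mode."""
    # generate() runs the decoding loop itself, so no torch.compile wrapper is used here.
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy decoding, matching the argmax in the reproduction
    )
    # Convert token ids back to strings, dropping special tokens as the reproduction does.
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
```

With the objects from the reproduction this would be called as `text = generate_eagerly(model, tokenizer, inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE)`.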

zyxiyy (Author) commented Jan 13, 2025

2 participants