Skip to content

Commit

Permalink
Fix Mixtral-related issues (#570)
Browse files Browse the repository at this point in the history
This PR fixes problems related to #569:
- block initialization
- throughput calculation and cache usage
- mixtral in tests

Beam search is removed for Mixtral and Llama for now. Those models use DynamicCache, which requires special function to change: (see https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py#L161)

---------

Co-authored-by: Max Ryabinin <[email protected]>
  • Loading branch information
artek0chumak and mryab authored Apr 10, 2024
1 parent d2fcbbc commit d6f4f80
Show file tree
Hide file tree
Showing 13 changed files with 79 additions and 34 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/run-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ jobs:
- { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
- { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
- { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
- { model: 'artek0chumak/TestMixtral', os: 'ubuntu', python-version: '3.8' }
- { model: 'artek0chumak/TestMixtral', os: 'ubuntu', python-version: '3.11' }
fail-fast: false
runs-on: ${{ matrix.os }}-latest
timeout-minutes: 20
Expand Down
2 changes: 1 addition & 1 deletion src/petals/client/remote_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def update_seen(self, new_seen: int) -> None:
self.seen_tokens += new_seen

def reorder_cache(self, beam_idx):
pass
raise NotImplementedError("Beam search reordering is not implemented yet")


_skipped_tokens = ContextVar("skipped_tokens", default=0)
Expand Down
6 changes: 6 additions & 0 deletions src/petals/models/bloom/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor

from petals.utils.misc import is_dummy


class WrappedBloomBlock(BloomBlock):
def forward(
Expand All @@ -22,6 +24,10 @@ def forward(
):
assert attention_mask is None, "Non-causal attention masks are not supported yet"
batch_size, seq_length = hidden_states.shape[:2]
if layer_past is not None and is_dummy(layer_past[0]):
# Bloom cannot use cache if it was misconsctructed(e.g. Dummy tensors)
# In this case, fallback to the old code:
layer_past = None
past_length = 0 if layer_past is None else layer_past[0].shape[-1]
seq_length_with_past = seq_length + past_length
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
Expand Down
12 changes: 6 additions & 6 deletions src/petals/models/mixtral/block.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Optional, Tuple

import torch
Expand Down Expand Up @@ -33,16 +34,15 @@ def forward(
past_key_values_length = 0

past_key_value = layer_past

if past_key_value is not None:
past_key_values_length = past_key_value[0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
_past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
past_key_value = DynamicCache()
for idx in range(self.layer_idx):
past_key_value.update(
torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx
)
past_key_value.update(_past_key_value[0], _past_key_value[1], self.layer_idx)
past_key_value.key_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[0]]
past_key_value.value_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[1]]
past_key_value._seen_tokens = past_key_values_length

if self._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
Expand Down Expand Up @@ -83,7 +83,7 @@ def forward(

if use_cache:
present_key_value = outputs[-1]
present_key_value = present_key_value.to_legacy_cache()[self.layer_idx]
present_key_value = present_key_value[self.layer_idx]
present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past)
outputs = outputs[:-1] + (present_key_value,)

Expand Down
21 changes: 15 additions & 6 deletions src/petals/models/mixtral/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,20 @@ def forward(
def word_embeddings(self) -> nn.Embedding: # For compatibility with RemoteGenerationMixin
return self.embed_tokens

@property
def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin in tests
return nn.Identity()

@property
def h(self) -> RemoteSequential: # For compatibility with RemoteGenerationMixin
return self.layers

@property
def ln_f(self) -> nn.Module: # For compatibility with RemoteGenerationMixin in tests
return self.norm

class DistributedMixtralForCausalLM(
DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM
):

class DistributedMixtralForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM):
_keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected

Expand All @@ -151,9 +157,12 @@ def transformer(self) -> DistributedMixtralModel: # For compatibility with Remo
return self.model


class DistributedMixtralForSequenceClassification(
DefaultRevisionMixin, FromPretrainedMixin, MixtralForSequenceClassification
):
class DistributedMixtralForSequenceClassification(FromPretrainedMixin, MixtralForSequenceClassification):
_keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected

config_class = DistributedMixtralConfig

def __init__(self, config: DistributedMixtralConfig):
MixtralPreTrainedModel.__init__(self, config)
self.num_labels = config.num_labels
Expand Down
17 changes: 15 additions & 2 deletions src/petals/server/block_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import torch
from accelerate import init_empty_weights
from transformers import PretrainedConfig
from transformers import PretrainedConfig, PreTrainedModel

from petals.models.mixtral.block import WrappedMixtralBlock
from petals.utils.convert_block import QuantType
from petals.utils.misc import get_size_in_bytes

Expand Down Expand Up @@ -32,7 +33,7 @@ def get_block_size(
), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations'

with init_empty_weights(include_buffers=True):
block = config.block_class(config)
block = get_model_block(config)
n_params = sum(param.numel() for param in block.parameters())

if location == "memory":
Expand All @@ -50,3 +51,15 @@ def get_block_size(
bytes_per_value = get_size_in_bytes(dtype)

return round(n_params * bytes_per_value * (1 + eps))


def get_model_block(config, layer_idx: int = 0):
"""
The function to create a model block based on the block class
kwargs argument **only** is necessary for specific classes, like Mixtral.
They will not be passed to other block constructors.
"""
if config.block_class == WrappedMixtralBlock:
config = PreTrainedModel._autoset_attn_implementation(config)
return config.block_class(config, layer_idx)
return config.block_class(config)
8 changes: 2 additions & 6 deletions src/petals/server/from_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from petals.constants import DTYPE_MAP
from petals.models.mixtral import WrappedMixtralBlock
from petals.server.block_utils import resolve_block_dtype
from petals.server.block_utils import get_model_block, resolve_block_dtype
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.hf_auth import always_needs_auth
Expand Down Expand Up @@ -52,11 +52,7 @@ def load_pretrained_block(
torch_dtype = resolve_block_dtype(config, torch_dtype)

with init_empty_weights():
if config.block_class == WrappedMixtralBlock:
config = PreTrainedModel._autoset_attn_implementation(config)
block = config.block_class(config, block_index)
else:
block = config.block_class(config)
block = get_model_block(config, layer_idx=block_index)

block_prefix = f"{config.block_prefix}.{block_index}."
state_dict = _load_state_dict_from_repo(
Expand Down
18 changes: 13 additions & 5 deletions src/petals/server/throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from hivemind.utils.logging import get_logger
from transformers import PretrainedConfig

from petals.server.block_utils import resolve_block_dtype
from petals.server.block_utils import get_model_block, resolve_block_dtype
from petals.utils.convert_block import QuantType, convert_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
from petals.utils.misc import DUMMY_KEY_PAST

logger = get_logger(__name__)

Expand Down Expand Up @@ -201,18 +202,25 @@ def measure_compute_rps(
if not tensor_parallel_devices:
tensor_parallel_devices = (device,)
with torch.inference_mode():
block = config.block_class(config).to(dtype)
block = get_model_block(config)
block = block.to(dtype)
block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)

cache = None
cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype))
elapsed = 0
dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
_, cache = block.forward(dummy_input, use_cache=True) # Skip the 1st step to exclude the initialization time

# Skip the 1st step to exclude the initialization time
def step(cache_):
outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None)
return outputs[1] if inference else None

cache = step(cache)
synchronize(device)

start_time = time.perf_counter()
for _ in range(n_steps):
_, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
cache = step(cache)
synchronize(device)
elapsed = time.perf_counter() - start_time
device_rps = n_steps * n_tokens / elapsed
Expand Down
2 changes: 2 additions & 0 deletions src/petals/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

DUMMY_INT64 = torch.empty(0, dtype=torch.int64)

DUMMY_KEY_PAST = torch.empty((0, 0, 0))


def is_dummy(tensor: torch.Tensor) -> bool:
return tensor.numel() == 0
Expand Down
4 changes: 2 additions & 2 deletions src/petals/utils/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from safetensors.torch import load_file
from transformers.utils import get_file_from_repo

from petals.server.block_utils import resolve_block_dtype
from petals.server.block_utils import get_model_block, resolve_block_dtype
from petals.utils.convert_block import QuantType
from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.misc import get_size_in_bytes
Expand Down Expand Up @@ -273,7 +273,7 @@ def estimate_adapter_memory_per_block(
) -> int:
"""Get the number of extra bytes used to store a set of adapters per given block"""
with init_empty_weights(include_buffers=True):
block = block_config.block_class(block_config)
block = get_model_block(block_config)
base_block_parameters = sum(p.numel() for p in block.parameters())
create_lora_adapter(block, quant_type=QuantType.NONE)

Expand Down
9 changes: 6 additions & 3 deletions tests/test_chained_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from petals import AutoDistributedConfig
from petals.client.remote_sequential import RemoteSequential
from petals.server.from_pretrained import load_pretrained_block
from petals.utils.misc import DUMMY_KEY_PAST
from test_utils import *


Expand Down Expand Up @@ -54,12 +55,14 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
outputs_inference = torch.cat(outputs_inference, dim=1)

dtype = torch.float32
ref_blocks = [
load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
load_pretrained_block(MODEL_NAME, 3, torch_dtype=dtype),
load_pretrained_block(MODEL_NAME, 4, torch_dtype=dtype),
]
outputs_ref = []
caches = [None, None]
cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype))
caches = [cache, cache]
for i in range(inputs.shape[1]):
new_caches = []
hidden_states = inputs[:, i : i + 1, :]
Expand Down
4 changes: 4 additions & 0 deletions tests/test_full_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ def test_sampling(tokenizer, model, ref_model, max_new_tokens=10):
), f"Sampling is not identical to HF with {options=}, {multiple_calls=}, {inputs.shape=}"


@pytest.mark.skipif(
"bloom" not in MODEL_NAME.lower(),
reason="Mixtral and Llama use DynamicCache, which can change based on beam search choices",
)
@pytest.mark.forked
def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5):
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
Expand Down
8 changes: 5 additions & 3 deletions tests/test_optimized_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

from petals.server.block_utils import get_model_block
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, convert_block
from test_utils import MODEL_NAME
Expand Down Expand Up @@ -195,8 +196,9 @@ def test_optimized_block(device):
dtype = torch.bfloat16
quant_type = QuantType.NONE

block = config.block_class(config).to(dtype)
block = convert_block(block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
block_idx = 1
block = get_model_block(config, layer_idx=block_idx).to(dtype)
block = convert_block(block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)

if config.model_type == "falcon":
unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
Expand All @@ -206,7 +208,7 @@ def test_optimized_block(device):
pytest.skip(f"This test is not applicable to {config.model_type} models")

unopt_block = convert_block(
unopt_block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
unopt_block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
)

unopt_block.load_state_dict(block.state_dict())
Expand Down

0 comments on commit d6f4f80

Please sign in to comment.