Use attention masks for TGI generation (#264)
fix(tgi): use attention masks for generation

Attention masks were simply ignored during generation, leading to
gibberish output if the inputs of some slots were padded (which
typically happens when adding a new request during generation).
dacorvo authored Oct 20, 2023
1 parent 77ba00e commit 571effd
Showing 1 changed file with 37 additions and 10 deletions.
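
To see why a missing attention mask produces gibberish for padded slots, here is a minimal, self-contained sketch of masked attention over a left-padded sequence. It is purely illustrative: the tensors, shapes, and the toy score computation are invented for the example and are not taken from the code changed in this commit.

# Toy demonstration (not from this repository): when the attention mask is
# ignored, padding positions receive non-zero attention weight, so their
# meaningless key/value vectors leak into every newly generated token.
import torch
import torch.nn.functional as F

torch.manual_seed(0)

seq_len, dim = 4, 8
# Pretend the first two positions of this slot are left-padding.
attention_mask = torch.tensor([0, 0, 1, 1])

q = torch.randn(1, dim)        # query for the token being generated
k = torch.randn(seq_len, dim)  # keys for the cached sequence (incl. padding)
scores = (q @ k.T) / dim**0.5  # raw attention scores, shape (1, seq_len)

# Ignoring the mask: padding gets real attention weight.
weights_unmasked = F.softmax(scores, dim=-1)

# Applying the mask: padding positions are pushed to ~0 weight.
masked_scores = scores.masked_fill(attention_mask == 0, float("-inf"))
weights_masked = F.softmax(masked_scores, dim=-1)

print("unmasked:", weights_unmasked)
print("masked:  ", weights_masked)

The diff below threads exactly this kind of mask through the server: each Slot now tracks its own mask, and both prefill and decode pass the reconstructed batch mask to _generate_token.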
@@ -105,6 +105,7 @@ def clear(self):
         self._inputs = ""
         self._generation_config = None
         self._tokens = []
+        self._mask = []
         self._selector = None
         self._generated_tokens = 0
         self._next_token_text = ""
@@ -157,16 +158,19 @@ def assign(self, request: Request, generation_config: GenerationConfig):
         self._generation_config.max_new_tokens = request.stopping_parameters.max_new_tokens
         # TODO: stop_sequences, ignore_eos_token

-    def reset(self, input_ids: torch.LongTensor, selector: TokenSelector):
+    def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, selector: TokenSelector):
         """Reset the slot for the next generation.

         Args:
             input_ids: (`torch.LongTensor`):
                 The new input_ids to use to generate the next token.
+            attention_mask: (`torch.LongTensor`):
+                The new attention_mask to use to generate the next token.
             selector: (`optimum.neuron.generation.TokenSelector`):
                 An object implementing the updated token selection logic.
         """
         self._tokens = input_ids.clone()
+        self._mask = attention_mask.clone()
         self._selector = selector

     def pause(self):
@@ -178,6 +182,10 @@ def pause(self):

     def resume(self):
         """Mark the slot as ready for generation."""
+        if self._state == Slot.State.PAUSE and self.next_token is not None:
+            # The generation of this slot was inhibited during a prefill, but it
+            # already had a pending token, so we need to increase attention mask
+            self._mask = torch.cat([self._mask, torch.LongTensor([1])])
         self._state = Slot.State.READY

     def append(self, next_token: int, next_token_text: str):
@@ -196,6 +204,7 @@ def append(self, next_token: int, next_token_text: str):
                 The corresponding decoded text.
         """
         self._tokens = torch.cat([self._tokens, torch.LongTensor([next_token])])
+        self._mask = torch.cat([self._mask, torch.LongTensor([1])])
         self._generated_tokens += 1
         # Now that a new token has been generated, we can append the previous one to the inputs
         self._inputs += self._next_token_text
@@ -227,6 +236,10 @@ def generated_text(self) -> str:
     def next_token(self) -> int:
         return None if len(self._tokens) == 0 else self._tokens[-1]

+    @property
+    def attention_mask(self) -> torch.LongTensor:
+        return self._mask
+
     @property
     def max_token(self) -> int:
         return self._generation_config.max_length
@@ -304,7 +317,8 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         # If needed truncate sequences to fit into the static dimensions
         seq_length = min(padded_inputs.input_ids.shape[-1], self.model.max_length)
         input_ids = padded_inputs.input_ids[:, :seq_length]
-        # Each slot must be reset with the padded inputs
+        attention_mask = padded_inputs.attention_mask[:, :seq_length]
+        # Each slot must be reset with the padded inputs and masks
         for i, slot in enumerate(self.slots):
             if slot.state != slot.state.EMPTY:
                 slot_input_ids = input_ids[i : i + 1, :]
@@ -313,15 +327,16 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
                     slot_input_ids, slot.generation_config, self.model, self.model.max_length
                 )
                 slot_input_ids = slot_input_ids.squeeze().type(torch.int64)
-                slot.reset(slot_input_ids, selector)
+                slot_attention_mask = attention_mask[i]
+                slot.reset(slot_input_ids, slot_attention_mask, selector)
         # Clear KV cache
         self.model.reset_generation()
         # Pause previously active slots during generation.
         # Their KV cache will be prefilled but new tokens will be ignored, as they
         # have already been generated and sent back in the last decode.
         for slot in active_slots:
             slot.pause()
-        generation, next_batch = self._generate_token(batch.id, input_ids)
+        generation, next_batch = self._generate_token(batch.id, input_ids, attention_mask)
         # Reactivate previously active slots for the next decode.
         for slot in active_slots:
             slot.resume()
@@ -345,16 +360,28 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
         # just carry on with decoding. We adopt the id of the first
         # batch in the list as our next batch id.
         next_batch_id = batches[0].id
-        # Construct input_ids from tokens generated by the last decode or prefill requests
-        empty = True
-        input_ids = torch.full([self.model.batch_size, 1], fill_value=self.tokenizer.eos_token_id, dtype=torch.int64)
+        # Reconstruct input_ids and attention_mask from slots
+        input_ids = None
+        attention_mask = None
         for i, slot in enumerate(self.slots):
             if slot.state != Slot.State.EMPTY:
+                if input_ids is None:
+                    # Create blank inputs covering all slots (even empty ones)
+                    input_ids = torch.full(
+                        [self.model.batch_size, 1], fill_value=self.tokenizer.eos_token_id, dtype=torch.int64
+                    )
+                # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
                 input_ids[i, 0] = slot.next_token
-                empty = False
-        if empty:
+                if attention_mask is None:
+                    # Create default mask covering all slots (even empty ones)
+                    attention_mask = torch.zeros(
+                        [self.model.batch_size, slot.attention_mask.size(-1)], dtype=torch.int64
+                    )
+                    attention_mask[:, -1] = 1
+                attention_mask[i, :] = slot.attention_mask
+        if input_ids is None:
             raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
-        return self._generate_token(next_batch_id, input_ids, attention_mask=None)
+        return self._generate_token(next_batch_id, input_ids, attention_mask)

     def _generate_token(
         self, next_batch_id: int, input_ids: torch.LongTensor, attention_mask: Optional[torch.LongTensor] = None
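
For readers following the decode() change above, the sketch below reproduces the same reconstruction pattern in isolation: batch-wide input_ids and attention_mask are built lazily from the slots that are still occupied, while empty slots keep a dummy token that attends only to itself. The slots list, batch_size, and pad_token_id here are hypothetical stand-ins for the server's Slot objects and model config, not the actual classes from this file.

# Standalone sketch of the decode-time mask reconstruction; the data below is
# invented, and the (next_token, mask) tuples stand in for the real Slot objects.
import torch

batch_size = 4
pad_token_id = 0

# Two occupied slots (per-slot masks share the same length) and two empty ones.
slots = [
    (torch.tensor(42), torch.tensor([0, 1, 1, 1])),  # padded earlier, then generated
    None,                                            # empty slot
    (torch.tensor(7), torch.tensor([1, 1, 1, 1])),   # never padded
    None,                                            # empty slot
]

input_ids = None
attention_mask = None
for i, slot in enumerate(slots):
    if slot is None:
        continue
    next_token, mask = slot
    if input_ids is None:
        # Blank inputs covering every slot, even the empty ones.
        input_ids = torch.full([batch_size, 1], pad_token_id, dtype=torch.int64)
    # Only the last generated token is fed back; earlier tokens are in the KV cache.
    input_ids[i, 0] = next_token
    if attention_mask is None:
        # Default mask: empty slots attend only to their own dummy token.
        attention_mask = torch.zeros([batch_size, mask.size(-1)], dtype=torch.int64)
        attention_mask[:, -1] = 1
    attention_mask[i, :] = mask

if input_ids is None:
    raise ValueError("no occupied slots to decode")

print(input_ids.squeeze(-1))  # tensor([42,  0,  7,  0])
print(attention_mask)

The lazy initialization mirrors the diff: the mask width is only known once an occupied slot is seen, since each slot carries a mask as long as its cached sequence.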
