Skip to content

Commit

Permalink
Return logits from single step
Browse files Browse the repository at this point in the history
  • Loading branch information
g-eoj committed May 28, 2024
1 parent c7f2cf2 commit 069808b
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 32 deletions.
6 changes: 4 additions & 2 deletions vllm/engine/output_processor/single_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,16 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
new_child_seq_id: int = next(self.seq_counter)
child = parent.fork(new_child_seq_id)
child.append_token_id(child_sample.output_token,
child_sample.logprobs)
child_sample.logprobs,
child_sample.output_logits)
child_seqs.append((child, parent))
# Continue the parent sequence for the last child sample.
# We reuse the parent sequence here to reduce redundant memory
# copies, especially when using non-beam search sampling methods.
last_child_sample = child_samples[-1]
parent.append_token_id(last_child_sample.output_token,
last_child_sample.logprobs)
last_child_sample.logprobs,
last_child_sample.output_logits)
child_seqs.append((parent, parent))

for seq, _ in child_seqs:
Expand Down
33 changes: 13 additions & 20 deletions vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def forward(
# Get the logprobs query results.
prompt_logprobs, sample_logprobs = _get_logprobs(
logprobs, sampling_metadata, sample_results)
return _build_sampler_output(sample_results,
return _build_sampler_output(logits,
sample_results,
sampling_metadata,
prompt_logprobs,
sample_logprobs,
Expand Down Expand Up @@ -970,6 +971,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,


def _build_sampler_output(
logits,
sample_results: SampleResultType,
sampling_metadata: SamplingMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]],
Expand All @@ -986,29 +988,20 @@ def _build_sampler_output(
speculative decoding rejection sampling.
"""

# If not specified, store None values in SamplerOutput.
if on_device_tensors is not None:
(sampled_token_probs, logprobs_tensor, sampled_token_ids) = on_device_tensors
else:
sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, None)

sampler_output = []
for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
sample_results, prompt_logprobs,
sample_logprobs):
for (seq_group, sample_result, group_prompt_logprobs, group_sample_logprobs) in zip(sampling_metadata.seq_groups, sample_results, prompt_logprobs, sample_logprobs):
seq_ids = seq_group.seq_ids
next_token_ids, parent_ids = sample_result
seq_outputs = []
for parent_id, next_token_id, logprobs in zip(parent_ids,
next_token_ids,
group_sample_logprobs):
seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
sampler_output.append(
SequenceGroupOutput(seq_outputs, group_prompt_logprobs))

# If not specified, store None values in SamplerOutput.
if on_device_tensors is not None:
(sampled_token_probs, logprobs_tensor,
sampled_token_ids) = on_device_tensors
else:
sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
None)
for parent_id, next_token_id, logprobs, sample_idx in zip(parent_ids, next_token_ids, group_sample_logprobs, seq_group.sample_indices):
seq_outputs.append(SequenceOutput(seq_ids[parent_id], next_token_id, logprobs, logits[sample_idx]))
sampler_output.append(SequenceGroupOutput(seq_outputs, group_prompt_logprobs))

return SamplerOutput(
outputs=sampler_output,
Expand Down
20 changes: 13 additions & 7 deletions vllm/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
text: str,
token_ids: List[int],
cumulative_logprob: float,
logits: List[float],
logprobs: Optional[SampleLogprobs],
finish_reason: Optional[str] = None,
stop_reason: Union[int, str, None] = None,
Expand All @@ -43,6 +44,7 @@ def __init__(
self.finish_reason = finish_reason
self.stop_reason = stop_reason
self.lora_request = lora_request
self.logits = logits

def finished(self) -> bool:
    """Return True once a finish reason has been recorded for this output.

    The output counts as finished whenever ``self.finish_reason`` is
    anything other than ``None`` (an empty string still counts).
    """
    return self.finish_reason is not None
Expand All @@ -52,6 +54,7 @@ def __repr__(self) -> str:
f"text={self.text!r}, "
f"token_ids={self.token_ids}, "
f"cumulative_logprob={self.cumulative_logprob}, "
f"logits={self.logits[0][:3]} ... {self.logits[0][-3:]}, "
f"logprobs={self.logprobs}, "
f"finish_reason={self.finish_reason}, "
f"stop_reason={self.stop_reason})")
Expand Down Expand Up @@ -114,13 +117,16 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
include_logprobs = seq_group.sampling_params.logprobs is not None
text_buffer_length = seq_group.sampling_params.output_text_buffer_length
outputs = [
CompletionOutput(seqs.index(seq),
seq.get_output_text_to_return(text_buffer_length),
seq.get_output_token_ids(),
seq.get_cumulative_logprob(),
seq.output_logprobs if include_logprobs else None,
SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason) for seq in top_n_seqs
CompletionOutput(
seqs.index(seq),
seq.get_output_text_to_return(text_buffer_length),
seq.get_output_token_ids(),
seq.get_cumulative_logprob(),
seq.get_output_logits(),
seq.output_logprobs if include_logprobs else None,
SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason,
) for seq in top_n_seqs
]

# Every sequence in the sequence group should have the same prompt.
Expand Down
21 changes: 18 additions & 3 deletions vllm/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,20 +113,23 @@ class SequenceData:
def __init__(
    self,
    prompt_token_ids: List[int],
    output_logits: Optional[List[List[float]]] = None,
    output_token_ids: Optional[List[int]] = None,
) -> None:
    """Initialise the token, logits and bookkeeping state of a sequence.

    Args:
        prompt_token_ids: Token ids of the prompt.
        output_logits: Previously collected logits, one entry per generated
            token. A fresh empty list is used when omitted.
        output_token_ids: Previously generated token ids. A fresh empty
            list is used when omitted.
    """
    # NOTE(review): output_logits is declared *before* output_token_ids, so
    # a caller passing output_token_ids positionally as the second argument
    # would bind it to output_logits -- confirm callers use keywords.
    if output_token_ids is None:
        output_token_ids = []
    # Use the same `is None` test as output_token_ids so that a caller's
    # (possibly empty) list is kept rather than silently replaced.
    if output_logits is None:
        output_logits = []

    self.prompt_token_ids = prompt_token_ids
    self.output_token_ids = output_token_ids
    self.output_logits = output_logits
    self.cumulative_logprob = 0.0
    # The number of tokens that are computed (that run against the model).
    self._num_computed_tokens = 0
    self._stage: SequenceStage = SequenceStage.PREFILL

def append_token_id(self, token_id: int, logprob: float,
                    logits: List[float]) -> None:
    """Record one generated token together with its logprob and logits.

    Args:
        token_id: The generated token id, appended to ``output_token_ids``.
        logprob: Log-probability of the token, added to
            ``cumulative_logprob``.
        logits: Logits associated with this step, appended as a single
            entry to ``output_logits``.
    """
    self.output_token_ids.append(token_id)
    self.output_logits.append(logits)
    self.cumulative_logprob += logprob

def get_len(self) -> int:
Expand All @@ -135,6 +138,9 @@ def get_len(self) -> int:
def get_prompt_len(self) -> int:
    """Return the number of token ids in ``self.prompt_token_ids``."""
    return len(self.prompt_token_ids)

def get_output_logits(self) -> List[List[float]]:
    """Return the stored per-token logits list.

    The returned object is ``self.output_logits`` itself (no copy), holding
    one entry per appended token.
    """
    return self.output_logits

def get_output_len(self) -> int:
    """Return the number of token ids in ``self.output_token_ids``."""
    return len(self.output_token_ids)

Expand Down Expand Up @@ -219,6 +225,7 @@ def __init__(
self.lora_request = lora_request

self.data: SequenceData = SequenceData(prompt_token_ids)
self.output_logits = []
self.output_logprobs: SampleLogprobs = []
self.output_text = ""

Expand Down Expand Up @@ -288,11 +295,13 @@ def append_token_id(
self,
token_id: int,
logprobs: Dict[int, Logprob],
logits: List[float],
) -> None:
assert token_id in logprobs
self._append_tokens_to_blocks([token_id])
self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob)
self.output_logits.append(logits)
self.data.append_token_id(token_id, logprobs[token_id].logprob, logits)

def get_len(self) -> int:
    """Return the sequence length as reported by ``self.data.get_len()``."""
    return self.data.get_len()
Expand All @@ -303,6 +312,9 @@ def get_prompt_len(self) -> int:
def get_output_len(self) -> int:
    """Return the output length as reported by ``self.data.get_output_len()``."""
    return self.data.get_output_len()

def get_output_logits(self) -> List[List[float]]:
    """Return the logits held by ``self.data`` via its ``get_output_logits()``."""
    return self.data.get_output_logits()

def get_token_ids(self) -> List[int]:
    """Return the token ids as reported by ``self.data.get_token_ids()``."""
    return self.data.get_token_ids()

Expand Down Expand Up @@ -644,13 +656,16 @@ def __init__(
parent_seq_id: int,
output_token: int,
logprobs: Dict[int, Logprob],
logits
) -> None:
self.parent_seq_id = parent_seq_id
self.output_logits = logits.tolist()
self.output_token = output_token
self.logprobs = logprobs

def __repr__(self) -> str:
    """Return a debug string for this sample.

    Only the *length* of ``output_logits`` is printed, not its contents;
    the parent sequence id, output token and logprobs are printed in full.
    """
    return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
            f"output_logits_len={len(self.output_logits)}, "
            f"output_token={self.output_token}, "
            f"logprobs={self.logprobs})")

Expand All @@ -677,7 +692,7 @@ def __init__(

def __repr__(self) -> str:
    """Return a debug string listing the samples and prompt logprobs.

    The string ends with the closing parenthesis of the repr; no trailing
    comma is emitted after it.
    """
    return (f"SequenceGroupOutput(samples={self.samples}, "
            f"prompt_logprobs={self.prompt_logprobs})")

def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceGroupOutput):
Expand Down

0 comments on commit 069808b

Please sign in to comment.