diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py
index d29a6880ff20e..e8b08d8c2ee88 100644
--- a/vllm/engine/output_processor/single_step.py
+++ b/vllm/engine/output_processor/single_step.py
@@ -103,7 +103,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                 child = parent.fork(new_child_seq_id)
                 child.append_token_id(child_sample.output_token,
                                       child_sample.logprobs,
-                                      child_sample.output_logits)
+                                      child_sample.output_classification_probs)
                 child_seqs.append((child, parent))
             # Continue the parent sequence for the last child sample.
             # We reuse the parent sequence here to reduce redundant memory
@@ -111,7 +111,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
             last_child_sample = child_samples[-1]
             parent.append_token_id(last_child_sample.output_token,
                                    last_child_sample.logprobs,
-                                   last_child_sample.output_logits)
+                                   last_child_sample.output_classification_probs)
             child_seqs.append((parent, parent))
 
         for seq, _ in child_seqs:
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 912d212a44349..c5e1ae88ccb18 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -46,6 +46,9 @@ def __init__(self):
         # speculative decoding.
         self.include_gpu_probs_tensor = False
 
+        self.classification_head = torch.nn.Linear(1, 1, bias=False).to("cuda")
+        self.classification_head.weight.data = torch.load("classification_head.pth", map_location="cuda").bfloat16()
+
     def forward(
         self,
         logits: torch.Tensor,
@@ -61,6 +64,10 @@
 
         logits = _apply_min_tokens_penalty(logits, sampling_metadata)
 
+        classification_probs = torch.nn.functional.sigmoid(
+            self.classification_head(logits)
+        ).flatten().tolist()
+
         # Prepare sampling tensors with pinned memory to avoid blocking.
         (sampling_tensors, do_penalties, do_top_p_top_k,
          do_min_p) = SamplingTensors.from_sampling_metadata(
@@ -110,7 +117,7 @@
         # Get the logprobs query results.
         prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, sampling_metadata, sample_results)
-        return _build_sampler_output(logits,
+        return _build_sampler_output(classification_probs,
                                      sample_results,
                                      sampling_metadata,
                                      prompt_logprobs,
@@ -971,7 +978,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
 
 
 def _build_sampler_output(
-    logits,
+    classification_probs,
     sample_results: SampleResultType,
     sampling_metadata: SamplingMetadata,
     prompt_logprobs: List[Optional[PromptLogprobs]],
@@ -1000,7 +1007,7 @@ def _build_sampler_output(
         next_token_ids, parent_ids = sample_result
         seq_outputs = []
         for parent_id, next_token_id, logprobs, sample_idx in zip(parent_ids, next_token_ids, group_sample_logprobs, seq_group.sample_indices):
-            seq_outputs.append(SequenceOutput(seq_ids[parent_id], next_token_id, logprobs, logits[sample_idx]))
+            seq_outputs.append(SequenceOutput(seq_ids[parent_id], next_token_id, logprobs, classification_probs[sample_idx]))
         sampler_output.append(SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
 
     return SamplerOutput(
diff --git a/vllm/outputs.py b/vllm/outputs.py
index 725d90cc72222..2878abec68f3a 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -30,7 +30,7 @@ def __init__(
         text: str,
         token_ids: List[int],
         cumulative_logprob: float,
-        logits: List[float],
+        classification_probs: List[float],
         logprobs: Optional[SampleLogprobs],
         finish_reason: Optional[str] = None,
         stop_reason: Union[int, str, None] = None,
@@ -44,7 +44,7 @@ def __init__(
         self.finish_reason = finish_reason
         self.stop_reason = stop_reason
         self.lora_request = lora_request
-        self.logits = logits
+        self.classification_probs = classification_probs
 
     def finished(self) -> bool:
         return self.finish_reason is not None
@@ -54,7 +54,7 @@ def __repr__(self) -> str:
                 f"text={self.text!r}, "
                 f"token_ids={self.token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob}, "
-                f"logits={self.logits[0][:3]} ... {self.logits[0][-3:]}, "
+                f"classification_probs={self.classification_probs}, "
                 f"logprobs={self.logprobs}, "
                 f"finish_reason={self.finish_reason}, "
                 f"stop_reason={self.stop_reason})")
@@ -122,7 +122,7 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
                 seq.get_output_text_to_return(text_buffer_length),
                 seq.get_output_token_ids(),
                 seq.get_cumulative_logprob(),
-                seq.get_output_logits(),
+                seq.get_output_classification_probs(),
                 seq.output_logprobs if include_logprobs else None,
                 SequenceStatus.get_finished_reason(seq.status),
                 seq.stop_reason,
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 2ffe0ffe4f1f3..5b74461bcc300 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -113,7 +113,7 @@ class SequenceData:
     def __init__(
         self,
         prompt_token_ids: List[int],
-        output_logits: Optional[List[float]] = None,
+        output_classification_probs: Optional[List[float]] = None,
         output_token_ids: Optional[List[int]] = None,
     ) -> None:
         if output_token_ids is None:
@@ -121,15 +121,15 @@ def __init__(
 
         self.prompt_token_ids = prompt_token_ids
         self.output_token_ids = output_token_ids
-        self.output_logits = output_logits or []
+        self.output_classification_probs = output_classification_probs or []
         self.cumulative_logprob = 0.0
         # The number of tokens that are computed (that run against the model).
         self._num_computed_tokens = 0
         self._stage: SequenceStage = SequenceStage.PREFILL
 
-    def append_token_id(self, token_id: int, logprob: float, logits: List[float]) -> None:
+    def append_token_id(self, token_id: int, logprob: float, classification_probs: List[float]) -> None:
         self.output_token_ids.append(token_id)
-        self.output_logits.append(logits)
+        self.output_classification_probs.append(classification_probs)
         self.cumulative_logprob += logprob
 
     def get_len(self) -> int:
@@ -138,8 +138,8 @@ def get_len(self) -> int:
     def get_prompt_len(self) -> int:
         return len(self.prompt_token_ids)
 
-    def get_output_logits(self) -> List[float]:
-        return self.output_logits
+    def get_output_classification_probs(self) -> List[float]:
+        return self.output_classification_probs
 
     def get_output_len(self) -> int:
         return len(self.output_token_ids)
@@ -225,7 +225,7 @@ def __init__(
         self.lora_request = lora_request
 
         self.data: SequenceData = SequenceData(prompt_token_ids)
-        self.output_logits = []
+        self.output_classification_probs = []
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""
 
@@ -295,13 +295,13 @@ def append_token_id(
         self,
         token_id: int,
         logprobs: Dict[int, Logprob],
-        logits: List[float],
+        classification_probs: List[float],
     ) -> None:
         assert token_id in logprobs
         self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.output_logits.append(logits)
-        self.data.append_token_id(token_id, logprobs[token_id].logprob, logits)
+        self.output_classification_probs.append(classification_probs)
+        self.data.append_token_id(token_id, logprobs[token_id].logprob, classification_probs)
 
     def get_len(self) -> int:
         return self.data.get_len()
@@ -312,8 +312,8 @@ def get_prompt_len(self) -> int:
         return self.data.get_prompt_len()
 
     def get_output_len(self) -> int:
         return self.data.get_output_len()
-    def get_output_logits(self) -> List[float]:
-        return self.output_logits
+    def get_output_classification_probs(self) -> List[float]:
+        return self.data.get_output_classification_probs()
     def get_token_ids(self) -> List[int]:
         return self.data.get_token_ids()
@@ -656,16 +656,16 @@ def __init__(
         parent_seq_id: int,
         output_token: int,
         logprobs: Dict[int, Logprob],
-        logits
+        classification_probs: List[float]
     ) -> None:
         self.parent_seq_id = parent_seq_id
-        self.output_logits = logits.tolist()
+        self.output_classification_probs = classification_probs
         self.output_token = output_token
         self.logprobs = logprobs
 
     def __repr__(self) -> str:
         return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
-                f"output_logits_len={len(self.output_logits)}, "
+                f"output_classification_probs={self.output_classification_probs}, "
                 f"output_token={self.output_token}, "
                 f"logprobs={self.logprobs})")
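
Reviewer sketch of the computation the sampler.py hunks introduce: the head is a single linear probe applied to the pre-sampling logits, so each generated token receives one sigmoid probability, which the rest of the patch threads through SequenceOutput, SequenceData, and CompletionOutput as classification_probs. This is a minimal standalone sketch, assuming classification_head.pth holds a (1, vocab_size) weight matching the model vocabulary; a random weight and dummy logits are substituted here so it runs without the checkpoint file.

import torch

# Assumed shapes: the sampler sees logits of shape (num_tokens, vocab_size);
# the saved probe weight is (1, vocab_size), giving one score per token row.
vocab_size = 32000
num_tokens = 4

head = torch.nn.Linear(vocab_size, 1, bias=False)
# The patch instead loads the weight: torch.load("classification_head.pth", ...)
head.weight.data = torch.randn(1, vocab_size)

logits = torch.randn(num_tokens, vocab_size)  # stand-in for the sampler's logits

# One sigmoid probability per sampled token; the patch flattens this to a
# Python list and passes it to _build_sampler_output as classification_probs.
classification_probs = torch.sigmoid(head(logits)).flatten().tolist()
print(classification_probs)  # num_tokens floats in [0, 1]

On the consumer side, following the renamed field in vllm/outputs.py, each generated sequence would then expose one probability per output token via CompletionOutput.classification_probs.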