fix _seen_tokens
xtinkt committed Jun 16, 2024
1 parent 068d934 commit 8539a08
Showing 2 changed files with 4 additions and 4 deletions.

src/petals/client/remote_generation.py (3 additions & 3 deletions)

@@ -22,20 +22,20 @@ class RemotePastKeyValues(Cache):

     def __init__(self) -> None:
         super().__init__()
-        self.seen_tokens = 0
+        self._seen_tokens = 0
         self.hypo_ids: Optional[torch.LongTensor] = None

     def __getitem__(self, _index: int) -> List[torch.Tensor]:
         return [DUMMY]  # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        return self.seen_tokens
+        return self._seen_tokens

     def get_max_length(self) -> Optional[int]:
         return None

     def update_seen(self, new_seen: int) -> None:
-        self.seen_tokens += new_seen
+        self._seen_tokens += new_seen

     def reorder_cache(self, beam_idx):
         raise NotImplementedError("Beam search reordering is not implemented yet")
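
The rename lines up with recent transformers releases, where the base Cache class tracks the token count in the private _seen_tokens attribute and exposes seen_tokens only as a deprecated, getter-only property. Under such a version, the old assignment self.seen_tokens = 0 would hit that setter-less property and fail. A minimal sketch of that failure mode; it stubs just the relevant shape of the property rather than importing transformers, and the Broken/Fixed names are illustrative only, not part of the commit:

class Cache:
    @property
    def seen_tokens(self):  # getter-only deprecated accessor, no setter defined
        return getattr(self, "_seen_tokens", None)


class Broken(Cache):
    def __init__(self) -> None:
        self.seen_tokens = 0  # raises AttributeError: the property has no setter


class Fixed(Cache):
    def __init__(self) -> None:
        self._seen_tokens = 0  # writes the private backing attribute instead


Fixed()  # constructs fine
try:
    Broken()
except AttributeError as exc:
    print("old spelling fails:", exc)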

src/petals/models/bloom/model.py (1 addition & 1 deletion)

@@ -131,7 +131,7 @@ def prepare_inputs_for_generation(
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
+                past_length = past_key_values._seen_tokens
                 max_cache_length = past_key_values.get_max_length()
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
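
For context, prepare_inputs_for_generation() has to handle both cache representations that generate() can hand it: a Cache object, queried through its methods plus the now-private counter, and the legacy tuple-of-tuples format, where the length is read off a tensor shape. A small sketch of that dispatch, assuming transformers and torch are installed and a legacy layout whose dim 2 is the cached sequence length; the cache_lengths helper is illustrative, not part of the commit:

from typing import Tuple

import torch
from transformers.cache_utils import Cache


def cache_lengths(past_key_values) -> Tuple[int, int]:
    # Mirrors the branch above: Cache objects expose methods,
    # legacy tuples expose only tensor shapes.
    if isinstance(past_key_values, Cache):
        cache_length = past_key_values.get_seq_length()
        past_length = past_key_values._seen_tokens  # private counter, per this commit
    else:
        cache_length = past_length = past_key_values[0][0].shape[2]
    return cache_length, past_length


# Legacy example: one layer of (key, value) with 5 cached tokens at dim 2.
legacy = ((torch.zeros(1, 2, 5, 4), torch.zeros(1, 2, 5, 4)),)
print(cache_lengths(legacy))  # -> (5, 5)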
