Skip to content

Commit

Permalink
add doc for Python CustomStreamer with buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-esir committed Aug 2, 2024
1 parent a295fe1 commit 8a38316
Showing 1 changed file with 64 additions and 0 deletions.
64 changes: 64 additions & 0 deletions src/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,70 @@ int main(int argc, char* argv[]) {
}
```
This Python example demonstrates custom detokenization with buferisation. The streamer receives
integer tokens corresponding to each word or subword, one by one. If tokens are decoded individually,
subwords will not be concatenated correctly, and the resulting text will lack appropriate spaces.
To address this, we accumulate tokens in a tokens_cache buffer and decode multiple tokens together,
returning the text only when a complete decoded chunk is ready.
```py
import openvino_genai as ov_genai
class TextPrintStreamer(ov_genai.StreamerBase):
def __init__(self, tokenizer):
super().__init__()
self.tokenizer = tokenizer
self.tokens_cache = []
self.print_len = 0
def get_stop_flag(self):
return False
def process_word(self, word: str):
print(word, end='', flush=True)
def put(self, token_id):
self.tokens_cache.append(token_id)
text = self.tokenizer.decode(self.tokens_cache)
word = ''
if len(text) > self.print_len and '\n' == text[-1]:
# Flush the cache after the new line symbol.
word = text[self.print_len:]
self.tokens_cache = []
self.print_len = 0
elif len(text) >= 3 and text[-3:] == "�":
# Don't print incomplete text.
pass
elif len(text) > self.print_len:
# It is possible to have a shorter text after adding new token.
# Print to output only if text lengh is increaesed.
word = text[self.print_len:]
self.print_len = len(text)
self.process_word(word)
if self.get_stop_flag():
# When generation is stopped from streamer then end is not called, need to call it here manually.
self.end()
return True # True means stop generation
else:
return False # False means continue generation
def end(self):
# Flush residual tokens from the buffer.
text = self.tokenizer.decode(self.tokens_cache)
if len(text) > self.print_len:
word = text[self.print_len:]
self.process_word(word)
self.tokens_cache = []
self.print_len = 0
pipe = ov_genai.LLMPipeline(model_path, "CPU")
text_print_streamer = TextPrintStreamer(pipe.get_tokenizer())
pipe.generate("The Sun is yellow because", max_new_tokens=15, streamer=text_print_streamer)
```

### Performance Metrics

`openvino_genai.PerfMetrics` (referred as `PerfMetrics` for simplicity) is a structure that holds performance metrics for each generate call. `PerfMetrics` holds fields with mean and standard deviations for the following metrics:
Expand Down

0 comments on commit 8a38316

Please sign in to comment.