Commit

Apply isort and black reformatting
Signed-off-by: akoumpa <[email protected]>
akoumpa committed Jan 10, 2025
1 parent e4251fc commit f12af7d
Showing 3 changed files with 25 additions and 8 deletions.
17 changes: 10 additions & 7 deletions nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
@@ -51,8 +51,8 @@ def __init__(
         legacy: bool = False,
         ignore_extra_whitespaces: bool = True,
         chat_template: Optional[Dict] = None,
-        trim_spm_separator_after_special_token = True,
-        spm_separator = '▁',
+        trim_spm_separator_after_special_token=True,
+        spm_separator='▁',
     ):
         self.chat_template = chat_template
         if not model_path or not os.path.exists(model_path):
@@ -148,16 +148,19 @@ def _text_to_ids(self, text, sample_alpha=None):
             next_token = min(indices, key=indices.get)
             next_idx = idx + indices[next_token]
 
-
             te = self.tokenizer.encode(text[idx:next_idx], out_type=str)
             text_tokens = self.tokenizer.encode(text[idx:next_idx])
             # Chat-templates insert a space between a special token and first word (e.g.
             # "[INST] who") which is tokenized as <inst-id> <space-id> <who-id> instead of
             # <inst-id> <who-id>.
-            if self.trim_spm_separator_after_special_token and len(ids) > 0 \
-                and ids[-1] in self.id_to_special_token \
-                and len(text_tokens) > 0 and text_tokens[0] == self.spm_separator_id:
-                text_tokens.pop(0)
+            if (
+                self.trim_spm_separator_after_special_token
+                and len(ids) > 0
+                and ids[-1] in self.id_to_special_token
+                and len(text_tokens) > 0
+                and text_tokens[0] == self.spm_separator_id
+            ):
+                text_tokens.pop(0)
             ids.extend(text_tokens)
             ids.append(self.special_token_to_id[next_token])
             idx = next_idx + len(next_token)
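
Beyond the mechanical black reformatting, the condition above is the tokenizer's workaround for SentencePiece emitting a '▁' separator token right after a special token. A minimal standalone sketch of that trimming rule, not NeMo's API; the token IDs below are hypothetical placeholders, and real values depend on the SentencePiece model:

# Standalone sketch of the separator-trimming rule (all IDs hypothetical).
SPECIAL_IDS = {3}         # e.g. the id of '[INST]'
SPM_SEPARATOR_ID = 29871  # e.g. the id of '▁'

def extend_with_trim(ids, text_tokens, trim=True):
    """Append text_tokens to ids, dropping a leading '▁' that directly
    follows a special token (mirrors the if-statement in the diff)."""
    if (
        trim
        and len(ids) > 0
        and ids[-1] in SPECIAL_IDS
        and len(text_tokens) > 0
        and text_tokens[0] == SPM_SEPARATOR_ID
    ):
        text_tokens = text_tokens[1:]
    ids.extend(text_tokens)
    return ids

# '[INST] who' would otherwise encode as [3, 29871, 7294];
# trimming restores [3, 7294].
print(extend_with_trim([1, 3], [29871, 7294]))  # [1, 3, 7294]
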
1 change: 1 addition & 0 deletions nemo/collections/nlp/modules/common/tokenizer_utils.py
@@ -171,6 +171,7 @@ def get_nmt_tokenizer(
"""
import omegaconf
from omegaconf import OmegaConf

if isinstance(special_tokens, omegaconf.listconfig.ListConfig):
special_tokens = OmegaConf.to_container(special_tokens)
if special_tokens is None:
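
The isort pass only adds a blank line after the local imports here, but the guarded conversion they support is worth illustrating: special_tokens may arrive as an omegaconf ListConfig rather than a plain list. A short sketch using omegaconf's documented API; the token values are illustrative:

import omegaconf
from omegaconf import OmegaConf

# A ListConfig, as produced when tokens come from a YAML/Hydra config.
special_tokens = OmegaConf.create(['<s>', '</s>', '[INST]'])
if isinstance(special_tokens, omegaconf.listconfig.ListConfig):
    # Convert to a plain Python list before handing it to the tokenizer.
    special_tokens = OmegaConf.to_container(special_tokens)
print(type(special_tokens).__name__)  # list
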
15 changes: 14 additions & 1 deletion tests/collections/nlp/test_tokenizer_with_special_tokens.py
@@ -16,7 +16,18 @@
 from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
 
 TOKENIZER_SPM_FILE = '/home/TestData/nlp/tokenizer_with_special_tokens/tokenizer.model'
-SPECIAL_TOKENS = ['<s>', '</s>', '[INST]', '[/INST]', '[TOOL_CALLS]', '[AVAILABLE_TOOLS]', '[/AVAILABLE_TOOLS]', '[TOOL_RESULTS]', '[/TOOL_RESULTS]']
+SPECIAL_TOKENS = [
+    '<s>',
+    '</s>',
+    '[INST]',
+    '[/INST]',
+    '[TOOL_CALLS]',
+    '[AVAILABLE_TOOLS]',
+    '[/AVAILABLE_TOOLS]',
+    '[TOOL_RESULTS]',
+    '[/TOOL_RESULTS]',
+]
+
 
 def _build_tokenizer(spm_file, special_tokens):
     tokenizer_cfg = {
@@ -39,13 +50,15 @@ def _build_tokenizer(spm_file, special_tokens):
         legacy=True,
     )
 
+
 def test_spm_with_special_tokens() -> None:
     tokenizer = _build_tokenizer(TOKENIZER_SPM_FILE, SPECIAL_TOKENS)
     assert tokenizer.text_to_ids('[INST]') == [3]
     for i, special_token in enumerate(SPECIAL_TOKENS):
         assert special_token in tokenizer.special_token_to_id, f'Expected {special_token} to be a special token'
         assert tokenizer.special_token_to_id[special_token] == i + 1
 
+
 def test_trim_spm_separator_after_special_token():
     tokenizer = _build_tokenizer(TOKENIZER_SPM_FILE, SPECIAL_TOKENS)
     assert tokenizer.text_to_ids('<s>[INST] Who') == [1, 3, 7294]
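
The whole commit is the output of running isort and black over these files. The signature change in the first diff can be reproduced with black's Python API; a sketch, assuming (as the diff's unchanged single quotes suggest) that NeMo's black configuration disables string normalization:

import black

# The pre-commit form of the signature from the first file's diff.
OLD = "def f(trim_spm_separator_after_special_token = True, spm_separator = '▁'):\n    pass\n"

# string_normalization=False keeps the single quotes, matching the diff.
print(black.format_str(OLD, mode=black.Mode(string_normalization=False)))
# def f(trim_spm_separator_after_special_token=True, spm_separator='▁'):
#     pass
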
