add test
Signed-off-by: Alexandros Koumparoulis <[email protected]>
akoumpa committed Jan 10, 2025
1 parent c35d2cb commit e4251fc
Showing 1 changed file with 14 additions and 6 deletions.
tests/collections/nlp/test_tokenizer_with_special_tokens.py (14 additions, 6 deletions)
@@ -16,9 +16,9 @@
 from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
 
 TOKENIZER_SPM_FILE = '/home/TestData/nlp/tokenizer_with_special_tokens/tokenizer.model'
+SPECIAL_TOKENS = ['<s>', '</s>', '[INST]', '[/INST]', '[TOOL_CALLS]', '[AVAILABLE_TOOLS]', '[/AVAILABLE_TOOLS]', '[TOOL_RESULTS]', '[/TOOL_RESULTS]']
 
-def test_spm_with_special_tokens() -> None:
-    special_tokens = ['<s>', '</s>', '[INST]', '[/INST]', '[TOOL_CALLS]', '[AVAILABLE_TOOLS]', '[/AVAILABLE_TOOLS]', '[TOOL_RESULTS]', '[/TOOL_RESULTS]']
+def _build_tokenizer(spm_file, special_tokens):
     tokenizer_cfg = {
         "library": "sentencepiece",
         "type": None,
@@ -28,18 +28,26 @@ def test_spm_with_special_tokens() -> None:
         "sentencepiece_legacy": True,
         "special_tokens": special_tokens,
     }
-    tokenizer = get_nmt_tokenizer(
+    return get_nmt_tokenizer(
         library=tokenizer_cfg['library'],
         model_name=tokenizer_cfg.get("type", None),
         use_fast=tokenizer_cfg.get("use_fast", False),
         delimiter=tokenizer_cfg.get("delimiter", None),
         special_tokens=tokenizer_cfg.get("special_tokens", None),
         trust_remote_code=tokenizer_cfg.get("trust_remote_code", False),
-        tokenizer_model=TOKENIZER_SPM_FILE,
+        tokenizer_model=spm_file,
         legacy=True,
     )
 
+def test_spm_with_special_tokens() -> None:
+    tokenizer = _build_tokenizer(TOKENIZER_SPM_FILE, SPECIAL_TOKENS)
     assert tokenizer.text_to_ids('[INST]') == [3]
-    for i, special_token in enumerate(special_tokens):
+    for i, special_token in enumerate(SPECIAL_TOKENS):
         assert special_token in tokenizer.special_token_to_id, f'Expected {special_token} to be a special token'
-        assert tokenizer.special_token_to_id[special_token] == i + 1
+        assert tokenizer.special_token_to_id[special_token] == i + 1
+
+def test_trim_spm_separator_after_special_token():
+    tokenizer = _build_tokenizer(TOKENIZER_SPM_FILE, SPECIAL_TOKENS)
+    assert tokenizer.text_to_ids('<s>[INST] Who') == [1, 3, 7294]
+    tokenizer.trim_spm_separator_after_special_token = False
+    assert tokenizer.text_to_ids('<s>[INST] Who') == [1, 3, 29473, 7294]
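Note on the new test: judging from the ids it expects, trim_spm_separator_after_special_token (on by default here) makes the tokenizer drop the SentencePiece separator piece that the legacy special-token path would otherwise emit right after a special token; with the flag turned off, that piece (id 29473 for this model) stays in the output. A minimal sketch of that behaviour, reusing the call from the test and assuming the test SPM model at TOKENIZER_SPM_FILE is available locally:

from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

# Same model and special tokens as in the test above.
SPECIAL_TOKENS = ['<s>', '</s>', '[INST]', '[/INST]', '[TOOL_CALLS]', '[AVAILABLE_TOOLS]', '[/AVAILABLE_TOOLS]', '[TOOL_RESULTS]', '[/TOOL_RESULTS]']

tokenizer = get_nmt_tokenizer(
    library='sentencepiece',
    model_name=None,
    use_fast=False,
    delimiter=None,
    special_tokens=SPECIAL_TOKENS,
    trust_remote_code=False,
    tokenizer_model='/home/TestData/nlp/tokenizer_with_special_tokens/tokenizer.model',
    legacy=True,
)

# Default behaviour: the separator piece after '<s>' and '[INST]' is trimmed.
print(tokenizer.text_to_ids('<s>[INST] Who'))  # the test expects [1, 3, 7294]

# With trimming disabled, the separator piece (id 29473) is kept.
tokenizer.trim_spm_separator_after_special_token = False
print(tokenizer.text_to_ids('<s>[INST] Who'))  # the test expects [1, 3, 29473, 7294]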
