From da653c0746cbdd705c5d7b90a41bf12546779f99 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Mon, 6 Jan 2025 14:29:28 -0800 Subject: [PATCH] add test Signed-off-by: Alexandros Koumparoulis --- .../nlp/test_tokenizer_with_special_tokens.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 tests/collections/nlp/test_tokenizer_with_special_tokens.py diff --git a/tests/collections/nlp/test_tokenizer_with_special_tokens.py b/tests/collections/nlp/test_tokenizer_with_special_tokens.py new file mode 100644 index 000000000000..3a07a13199c0 --- /dev/null +++ b/tests/collections/nlp/test_tokenizer_with_special_tokens.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import json
import os
import random
from functools import partial

import pytest

from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

# SentencePiece model fixture loaded by the test below.
TOKENIZER_SPM_FILE = '/home/TestData/nlp/tokenizer_with_special_tokens/tokenizer.model'


def test_spm_with_special_tokens() -> None:
    """Check that special tokens given to a SentencePiece tokenizer are registered.

    Builds a tokenizer from ``TOKENIZER_SPM_FILE`` through ``get_nmt_tokenizer``,
    passing a list of special tokens, then asserts that ``'[INST]'`` encodes to
    the single id ``3`` and that every requested token is a key of the
    tokenizer's ``special_token_to_id`` mapping.
    """
    special_tokens = [
        '[INST]',
        '[/INST]',
        '[TOOL_CALLS]',
        '[AVAILABLE_TOOLS]',
        '[/AVAILABLE_TOOLS]',
        '[TOOL_RESULTS]',
        '[/TOOL_RESULTS]',
    ]
    tokenizer_cfg = {
        "library": "sentencepiece",
        "type": None,
        "vocab_file": None,
        "merge_file": None,
        "delimiter": None,
        "sentencepiece_legacy": True,
        "special_tokens": special_tokens,
    }
    tokenizer = get_nmt_tokenizer(
        library=tokenizer_cfg['library'],
        model_name=tokenizer_cfg.get("type", None),
        use_fast=tokenizer_cfg.get("use_fast", False),
        delimiter=tokenizer_cfg.get("delimiter", None),
        special_tokens=tokenizer_cfg.get("special_tokens", None),
        trust_remote_code=tokenizer_cfg.get("trust_remote_code", False),
        tokenizer_model=TOKENIZER_SPM_FILE,
        # Read the flag from the config instead of repeating the literal, so the
        # "sentencepiece_legacy" entry above is the single source of truth.
        legacy=tokenizer_cfg.get("sentencepiece_legacy", False),
    )

    # NOTE(review): the expected id 3 is specific to the fixture model; confirm
    # against tokenizer.model if the fixture is ever regenerated.
    inst_ids = tokenizer.text_to_ids('[INST]')
    assert inst_ids == [3], f'Expected [INST] to encode to [3], got {inst_ids}'
    for special_token in special_tokens:
        assert special_token in tokenizer.special_token_to_id, f'Expected {special_token} to be a special token'