diff --git a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
index 56a4b04dfe0f..45fbd4c8b328 100644
--- a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
+++ b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
@@ -237,6 +237,9 @@ def add_special_tokens(self, special_tokens):
                     self.special_token_to_id[token] = self.vocab_size
                     self.id_to_special_token[self.vocab_size] = token
                     self.vocab_size += 1
+                elif self.tokenizer.piece_to_id(token) != self.tokenizer.unk_id():
+                    self.special_token_to_id[token] = self.tokenizer.piece_to_id(token)
+
         elif isinstance(special_tokens, dict):
             for token_name, token in special_tokens.items():
                 setattr(self, token_name, token)
@@ -247,6 +250,8 @@ def add_special_tokens(self, special_tokens):
                     self.special_token_to_id[token] = self.vocab_size
                     self.id_to_special_token[self.vocab_size] = token
                     self.vocab_size += 1
+        else:
+            raise ValueError("Expected special_tokens to be a list or a dict " + str(type(special_tokens)))
 
     @property
     def pad_id(self):
diff --git a/nemo/collections/nlp/modules/common/tokenizer_utils.py b/nemo/collections/nlp/modules/common/tokenizer_utils.py
index 31bf493ec776..f066ade86811 100644
--- a/nemo/collections/nlp/modules/common/tokenizer_utils.py
+++ b/nemo/collections/nlp/modules/common/tokenizer_utils.py
@@ -169,6 +169,11 @@ def get_nmt_tokenizer(
             It has empirically been shown to improve inference time BLEU scores.
         r2l: Whether to return subword IDs from right to left
     """
+    import omegaconf
+    from omegaconf import OmegaConf
+
+    if isinstance(special_tokens, omegaconf.listconfig.ListConfig):
+        special_tokens = OmegaConf.to_container(special_tokens)
     if special_tokens is None:
         special_tokens_dict = {}
     else:
@@ -195,8 +200,10 @@ def get_nmt_tokenizer(
         from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
 
         logging.info(f'Getting SentencePiece with model: {tokenizer_model}')
+
         return SentencePieceTokenizer(
             model_path=tokenizer_model,
+            special_tokens=special_tokens,
             legacy=legacy,
             chat_template=chat_template,
         )
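For reference, a minimal sketch of what the extended constructor enables when used directly. The model path and token values below are illustrative assumptions, not part of this change:

from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer

# Hypothetical .model file; any SentencePiece model that already contains these pieces works.
tokenizer = SentencePieceTokenizer(
    model_path='/path/to/tokenizer.model',
    special_tokens=['<s>', '</s>', '[INST]', '[/INST]'],
    legacy=True,
)

# Tokens already present in the SentencePiece vocab are mapped to their existing ids
# (the new piece_to_id branch); unseen tokens are still appended after vocab_size as before.
print(tokenizer.special_token_to_id)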
diff --git a/tests/collections/nlp/test_tokenizer_with_special_tokens.py b/tests/collections/nlp/test_tokenizer_with_special_tokens.py
new file mode 100644
index 000000000000..d042231f6670
--- /dev/null
+++ b/tests/collections/nlp/test_tokenizer_with_special_tokens.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
+
+TOKENIZER_SPM_FILE = '/home/TestData/nlp/tokenizer_with_special_tokens/tokenizer.model'
+
+
+def test_spm_with_special_tokens() -> None:
+    special_tokens = [
+        '<s>',
+        '</s>',
+        '[INST]',
+        '[/INST]',
+        '[TOOL_CALLS]',
+        '[AVAILABLE_TOOLS]',
+        '[/AVAILABLE_TOOLS]',
+        '[TOOL_RESULTS]',
+        '[/TOOL_RESULTS]',
+    ]
+    tokenizer_cfg = {
+        "library": "sentencepiece",
+        "type": None,
+        "vocab_file": None,
+        "merge_file": None,
+        "delimiter": None,
+        "sentencepiece_legacy": True,
+        "special_tokens": special_tokens,
+    }
+    tokenizer = get_nmt_tokenizer(
+        library=tokenizer_cfg['library'],
+        model_name=tokenizer_cfg.get("type", None),
+        use_fast=tokenizer_cfg.get("use_fast", False),
+        delimiter=tokenizer_cfg.get("delimiter", None),
+        special_tokens=tokenizer_cfg.get("special_tokens", None),
+        trust_remote_code=tokenizer_cfg.get("trust_remote_code", False),
+        tokenizer_model=TOKENIZER_SPM_FILE,
+        legacy=True,
+    )
+
+    assert tokenizer.text_to_ids('[INST]') == [3]
+    for i, special_token in enumerate(special_tokens):
+        assert special_token in tokenizer.special_token_to_id, f'Expected {special_token} to be a special token'
+        assert tokenizer.special_token_to_id[special_token] == i + 1
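As a usage note, a hedged sketch of the end-to-end path this change enables: special tokens coming straight from a Hydra/OmegaConf config (a ListConfig) are now converted to a plain list inside get_nmt_tokenizer and forwarded to SentencePieceTokenizer. The config values and model path below are illustrative assumptions:

from omegaconf import OmegaConf

from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

# Hypothetical tokenizer section of a training config; the .model path is a placeholder.
cfg = OmegaConf.create(
    {
        'library': 'sentencepiece',
        'model': '/path/to/tokenizer.model',
        'special_tokens': ['<s>', '</s>', '[INST]', '[/INST]'],
        'sentencepiece_legacy': True,
    }
)

tokenizer = get_nmt_tokenizer(
    library=cfg.library,
    tokenizer_model=cfg.model,
    special_tokens=cfg.special_tokens,  # ListConfig; converted via OmegaConf.to_container
    legacy=cfg.sentencepiece_legacy,
)

print(tokenizer.special_token_to_id)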