fix sentencepiece tokenizer special tokens (#11767)
* Allow special tokens in vocab

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* pass special tokens to SentencePieceTokenizer

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add test

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* cleanup

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* cleanup

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Co-authored-by: akoumpa <[email protected]>
akoumpa and akoumpa authored Jan 7, 2025
1 parent 5ed118f commit d98b7cd
Showing 3 changed files with 68 additions and 0 deletions.
5 changes: 5 additions & 0 deletions nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
@@ -237,6 +237,9 @@ def add_special_tokens(self, special_tokens):
                    self.special_token_to_id[token] = self.vocab_size
                    self.id_to_special_token[self.vocab_size] = token
                    self.vocab_size += 1
                elif self.tokenizer.piece_to_id(token) != self.tokenizer.unk_id():
                    self.special_token_to_id[token] = self.tokenizer.piece_to_id(token)

        elif isinstance(special_tokens, dict):
            for token_name, token in special_tokens.items():
                setattr(self, token_name, token)
@@ -247,6 +250,8 @@ def add_special_tokens(self, special_tokens):
                    self.special_token_to_id[token] = self.vocab_size
                    self.id_to_special_token[self.vocab_size] = token
                    self.vocab_size += 1
        else:
            raise ValueError("Expected special_tokens to be a list or a dict " + str(type(special_tokens)))

    @property
    def pad_id(self):
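For context, a minimal sketch (not part of this commit) of how the new list branch behaves: a special token that already exists as a piece in the SentencePiece model keeps its original piece id, while an unseen token is appended after the base vocabulary. The model path and the '<extra_token>' piece are placeholders.

from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer

# Placeholder model path; assumes a SentencePiece model that contains '[INST]' as a piece.
tokenizer = SentencePieceTokenizer(model_path='/path/to/tokenizer.model', legacy=True)
base_vocab_size = tokenizer.vocab_size

# '[INST]' is assumed to exist in the model's vocab; '<extra_token>' is assumed not to.
tokenizer.add_special_tokens(['[INST]', '<extra_token>'])

# '[INST]' now maps to its existing piece id (the new elif branch);
# '<extra_token>' is assigned the next id after the base vocabulary.
assert tokenizer.special_token_to_id['<extra_token>'] == base_vocab_size
print(tokenizer.special_token_to_id)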
7 changes: 7 additions & 0 deletions nemo/collections/nlp/modules/common/tokenizer_utils.py
@@ -169,6 +169,11 @@ def get_nmt_tokenizer(
            It has empirically been shown to improve inference time BLEU scores.
        r2l: Whether to return subword IDs from right to left
    """
    import omegaconf
    from omegaconf import OmegaConf

    if isinstance(special_tokens, omegaconf.listconfig.ListConfig):
        special_tokens = OmegaConf.to_container(special_tokens)
    if special_tokens is None:
        special_tokens_dict = {}
    else:
@@ -195,8 +200,10 @@
        from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer

        logging.info(f'Getting SentencePiece with model: {tokenizer_model}')

        return SentencePieceTokenizer(
            model_path=tokenizer_model,
            special_tokens=special_tokens,
            legacy=legacy,
            chat_template=chat_template,
        )
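A minimal usage sketch (not part of this commit) of the conversion added to get_nmt_tokenizer: an OmegaConf ListConfig of special tokens, as typically read from a YAML config, is converted to a plain list before being forwarded to SentencePieceTokenizer. The model path below is a placeholder.

from omegaconf import OmegaConf

from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

# cfg.special_tokens is an omegaconf ListConfig, not a plain Python list.
cfg = OmegaConf.create({"special_tokens": ["[INST]", "[/INST]"]})

tokenizer = get_nmt_tokenizer(
    library="sentencepiece",
    tokenizer_model="/path/to/tokenizer.model",  # placeholder path
    special_tokens=cfg.special_tokens,  # ListConfig is now accepted and converted internally
    legacy=True,
)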
56 changes: 56 additions & 0 deletions tests/collections/nlp/test_tokenizer_with_special_tokens.py
@@ -0,0 +1,56 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

TOKENIZER_SPM_FILE = '/home/TestData/nlp/tokenizer_with_special_tokens/tokenizer.model'


def test_spm_with_special_tokens() -> None:
    special_tokens = [
        '<s>',
        '</s>',
        '[INST]',
        '[/INST]',
        '[TOOL_CALLS]',
        '[AVAILABLE_TOOLS]',
        '[/AVAILABLE_TOOLS]',
        '[TOOL_RESULTS]',
        '[/TOOL_RESULTS]',
    ]
    tokenizer_cfg = {
        "library": "sentencepiece",
        "type": None,
        "vocab_file": None,
        "merge_file": None,
        "delimiter": None,
        "sentencepiece_legacy": True,
        "special_tokens": special_tokens,
    }
    tokenizer = get_nmt_tokenizer(
        library=tokenizer_cfg['library'],
        model_name=tokenizer_cfg.get("type", None),
        use_fast=tokenizer_cfg.get("use_fast", False),
        delimiter=tokenizer_cfg.get("delimiter", None),
        special_tokens=tokenizer_cfg.get("special_tokens", None),
        trust_remote_code=tokenizer_cfg.get("trust_remote_code", False),
        tokenizer_model=TOKENIZER_SPM_FILE,
        legacy=True,
    )

    assert tokenizer.text_to_ids('[INST]') == [3]
    for i, special_token in enumerate(special_tokens):
        assert special_token in tokenizer.special_token_to_id, f'Expected {special_token} to be a special token'
        assert tokenizer.special_token_to_id[special_token] == i + 1
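

Beyond the list form exercised by this test, add_special_tokens also accepts a dict (first hunk above), and any other type now raises the new ValueError. A hedged sketch, not part of this commit, reusing the test's tokenizer model; the attribute names are illustrative.

tokenizer = get_nmt_tokenizer(
    library="sentencepiece",
    tokenizer_model=TOKENIZER_SPM_FILE,
    legacy=True,
)

# Dict form: each key becomes an attribute on the tokenizer holding its token string.
tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>"})
assert tokenizer.bos_token == "<s>"

# Any other type now fails fast with the new ValueError:
# tokenizer.add_special_tokens("<s>")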
