Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix sentencepiece tokenizer special tokens #11767

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,9 @@ def add_special_tokens(self, special_tokens):
self.special_token_to_id[token] = self.vocab_size
self.id_to_special_token[self.vocab_size] = token
self.vocab_size += 1
elif self.tokenizer.piece_to_id(token) != self.tokenizer.unk_id():
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if a token is already in the vocab, we want to allow users to mark them as special tokens.

self.special_token_to_id[token] = self.tokenizer.piece_to_id(token)

elif isinstance(special_tokens, dict):
for token_name, token in special_tokens.items():
setattr(self, token_name, token)
Expand All @@ -247,6 +250,8 @@ def add_special_tokens(self, special_tokens):
self.special_token_to_id[token] = self.vocab_size
self.id_to_special_token[self.vocab_size] = token
self.vocab_size += 1
else:
raise ValueError("Expected special_tokens to be a list or a dict " + str(type(special_tokens)))

@property
def pad_id(self):
Expand Down
7 changes: 7 additions & 0 deletions nemo/collections/nlp/modules/common/tokenizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,11 @@ def get_nmt_tokenizer(
It has empirically been shown to improve inference time BLEU scores.
r2l: Whether to return subword IDs from right to left
"""
import omegaconf
from omegaconf import OmegaConf

if isinstance(special_tokens, omegaconf.listconfig.ListConfig):
special_tokens = OmegaConf.to_container(special_tokens)
if special_tokens is None:
special_tokens_dict = {}
else:
Expand All @@ -195,8 +200,10 @@ def get_nmt_tokenizer(
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer

logging.info(f'Getting SentencePiece with model: {tokenizer_model}')

return SentencePieceTokenizer(
model_path=tokenizer_model,
special_tokens=special_tokens,
legacy=legacy,
chat_template=chat_template,
)
Expand Down
56 changes: 56 additions & 0 deletions tests/collections/nlp/test_tokenizer_with_special_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

TOKENIZER_SPM_FILE = '/home/TestData/nlp/tokenizer_with_special_tokens/tokenizer.model'


def test_spm_with_special_tokens() -> None:
special_tokens = [
'<s>',
'</s>',
'[INST]',
'[/INST]',
'[TOOL_CALLS]',
'[AVAILABLE_TOOLS]',
'[/AVAILABLE_TOOLS]',
'[TOOL_RESULTS]',
'[/TOOL_RESULTS]',
]
tokenizer_cfg = {
"library": "sentencepiece",
"type": None,
"vocab_file": None,
"merge_file": None,
"delimiter": None,
"sentencepiece_legacy": True,
"special_tokens": special_tokens,
}
tokenizer = get_nmt_tokenizer(
library=tokenizer_cfg['library'],
model_name=tokenizer_cfg.get("type", None),
use_fast=tokenizer_cfg.get("use_fast", False),
delimiter=tokenizer_cfg.get("delimiter", None),
special_tokens=tokenizer_cfg.get("special_tokens", None),
trust_remote_code=tokenizer_cfg.get("trust_remote_code", False),
tokenizer_model=TOKENIZER_SPM_FILE,
legacy=True,
)

assert tokenizer.text_to_ids('[INST]') == [3]
for i, special_token in enumerate(special_tokens):
assert special_token in tokenizer.special_token_to_id, f'Expected {special_token} to be a special token'
assert tokenizer.special_token_to_id[special_token] == i + 1
Loading