Skip to content

Commit

Permalink
Handle the punctuation definition mismatch between different Unicode …
Browse files Browse the repository at this point in the history
…versions.

PiperOrigin-RevId: 707239296
  • Loading branch information
tf-text-github-robot committed Dec 18, 2024
1 parent 31f22e9 commit e7d896c
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 5 deletions.
3 changes: 3 additions & 0 deletions tensorflow_text/core/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -403,12 +403,15 @@ cc_test(
srcs = ["fast_wordpiece_tokenizer_test.cc"],
data = [
"//tensorflow_text:python/ops/test_data/fast_wordpiece_tokenizer_model.fb",
"//tensorflow_text:python/ops/test_data/fast_wordpiece_tokenizer_model_ver_15_1.fb",
"//tensorflow_text:python/ops/test_data/fast_wordpiece_tokenizer_model_ver_16_0.fb",
],
deps = [
":fast_wordpiece_tokenizer",
":fast_wordpiece_tokenizer_model_builder",
"@com_google_googletest//:gtest_main",
"@com_google_absl//absl/flags:flag",
"//third_party/icu:headers",
# tf:lib tensorflow dep,
],
)
Expand Down
20 changes: 15 additions & 5 deletions tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,14 +278,24 @@ void FastWordpieceTokenizer::TokenizeTextImpl(
prev_unicode_char))) {
// If the current Unicode character is a valid word boundary, collect the
// remaining tokens stored on a path on the trie.
absl::string_view cur_str = absl::string_view(
input_substr.data(), cur_pos - input_word_offset_in_text);
HandleTheRemainingStringOnTriePath<kGetPieces, kGetIds, kGetOffsets>(
absl::string_view(input_substr.data(),
cur_pos - input_word_offset_in_text),
input_word_offset_in_text, cur_node, original_num_tokens,
cur_str, input_word_offset_in_text, cur_node, original_num_tokens,
cur_offset_in_input_word, output_pieces, output_ids,
output_start_offsets, output_end_offsets);
// Skip the whitespace.
if (is_white_space) cur_pos = next_pos;
if (is_white_space) {
// Skip the whitespace.
cur_pos = next_pos;
} else if (cur_str.empty()) {
// If the remaining tokens are empty, it means we encountered an
// unmappable separator, so output an unknown token and continue.
cur_pos = next_pos;
ResetOutputAppendUnknownToken<kGetPieces, kGetIds, kGetOffsets>(
input_word_offset_in_text, (cur_pos - input_word_offset_in_text),
original_num_tokens, output_pieces, output_ids,
output_start_offsets, output_end_offsets);
}
// Continue in the outer while loop to process the remaining input.
continue;
}
Expand Down
67 changes: 67 additions & 0 deletions tensorflow_text/core/kernels/fast_wordpiece_tokenizer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/flags/flag.h"
#include "icu4c/source/common/unicode/uchar.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_builder.h"

namespace tensorflow {
namespace text {
namespace {

using ::testing::AnyOf;
using ::testing::ElementsAre;

constexpr char kTestConfigPath[] =
Expand Down Expand Up @@ -58,6 +60,71 @@ TEST(FastWordpieceTokenizerTest, LoadAndTokenize) {
EXPECT_THAT(output_end_offsets, ElementsAre(3, 5, 6, 9));
}

using TestPunctuationVersionMismatch = testing::TestWithParam<std::string>;

TEST_P(TestPunctuationVersionMismatch, Test) {
// The config_flatbuffer used here is built from the following config:
// * vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f",
// "##ghz", "<unk>"}
// * unk_token = "<unk>"
// * suffix_indicator = "##"
// * max_bytes_per_token = 100
// * end_to_end = True

const std::string kTestConfigUnicodePath = GetParam();

// We test the new punctuation symbol: \341\255\277, which was available in
// Unicode 16: https://www.fileformat.info/info/unicode/char//1b7f/index.htm,
// but not in 15.1.
// We also test an existing punctuation symbol ">".
std::string input = "abc>abc\341\255\277abc";

std::string config_flatbuffer;
auto status = tensorflow::ReadFileToString(
tensorflow::Env::Default(), kTestConfigUnicodePath, &config_flatbuffer);
ASSERT_TRUE(status.ok());

ASSERT_OK_AND_ASSIGN(
auto tokenizer, FastWordpieceTokenizer::Create(config_flatbuffer.data()));

std::vector<std::string> output_tokens;
std::vector<int> output_ids;
std::vector<int> output_start_offsets;
std::vector<int> output_end_offsets;
tokenizer.Tokenize(input, &output_tokens, &output_ids, &output_start_offsets,
&output_end_offsets);

// If the runtime environment has unicode <=15.1, "\341\255\277" is not a
// punctuation, so "abc\341\255\277abc" is one token.
// If the runtime environment has unicode >=16.0, "\341\255\277" is a
// punctuation, so tokens are "abc", "<unk>", "abc"
EXPECT_THAT(output_tokens.size(), AnyOf(3, 5));
if (!u_ispunct(0x1b7f)) {
// We have a runtime environment of unicode <= 15.1.
EXPECT_THAT(output_tokens, ElementsAre("abc", "<unk>", "<unk>"));
EXPECT_THAT(output_ids, ElementsAre(1, 8, 8));
EXPECT_THAT(output_start_offsets, ElementsAre(0, 3, 4));
EXPECT_THAT(output_end_offsets, ElementsAre(3, 4, 13));
} else {
// We have a runtime environment of unicode >= 16.0.
EXPECT_THAT(output_tokens,
ElementsAre("abc", "<unk>", "abc", "<unk>", "abc"));
EXPECT_THAT(output_ids, ElementsAre(1, 8, 1, 8, 1));
EXPECT_THAT(output_start_offsets, ElementsAre(0, 3, 4, 7, 10));
EXPECT_THAT(output_end_offsets, ElementsAre(3, 4, 7, 10, 13));
}
}

INSTANTIATE_TEST_SUITE_P(FastWordpieceTokenizerPunctuationTest,
TestPunctuationVersionMismatch,
testing::Values(
// Unicode v 15.1 config
"third_party/tensorflow_text/python/ops/test_data/"
"fast_wordpiece_tokenizer_model_ver_15_1.fb",
// Unicode v 16.0 config
"third_party/tensorflow_text/python/ops/test_data/"
"fast_wordpiece_tokenizer_model_ver_16_0.fb"));

template <typename T>
std::string ListToString(const std::vector<T>& list) {
return absl::StrCat("[", absl::StrJoin(list, ", "), "]");
Expand Down
Binary file not shown.
Binary file not shown.

0 comments on commit e7d896c

Please sign in to comment.