From fc9e81d4ab50f62d72d0db381465600105940743 Mon Sep 17 00:00:00 2001
From: Anthony MOI <m.anthony.moi@gmail.com>
Date: Sun, 12 Jan 2020 02:35:45 -0500
Subject: [PATCH] Fix split on special tokens & bump version

---
 bindings/python/Cargo.lock             |  6 +++---
 bindings/python/Cargo.toml             |  2 +-
 bindings/python/setup.py               |  2 +-
 bindings/python/tokenizers/__init__.py |  2 +-
 tokenizers/Cargo.toml                  |  2 +-
 tokenizers/src/tokenizer/mod.rs        | 32 ++++++++++++++++++++------
 6 files changed, 32 insertions(+), 14 deletions(-)

diff --git a/bindings/python/Cargo.lock b/bindings/python/Cargo.lock
index b8c7d8f36..3c2f2e56e 100644
--- a/bindings/python/Cargo.lock
+++ b/bindings/python/Cargo.lock
@@ -555,7 +555,7 @@ dependencies = [
 
 [[package]]
 name = "tokenizers"
-version = "0.6.0"
+version = "0.6.1"
 dependencies = [
  "clap 2.33.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "indicatif 0.13.0 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -571,10 +571,10 @@ dependencies = [
 
 [[package]]
 name = "tokenizers-python"
-version = "0.1.0"
+version = "0.1.1"
 dependencies = [
  "pyo3 0.8.4 (registry+https://github.com/rust-lang/crates.io-index)",
- "tokenizers 0.6.0",
+ "tokenizers 0.6.1",
 ]
 
 [[package]]
diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml
index fbb5785bf..03e583bfd 100644
--- a/bindings/python/Cargo.toml
+++ b/bindings/python/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "tokenizers-python"
-version = "0.1.0"
+version = "0.1.1"
 authors = ["Anthony MOI <m.anthony.moi@gmail.com>"]
 edition = "2018"
 
diff --git a/bindings/python/setup.py b/bindings/python/setup.py
index 1cc963441..b31427486 100644
--- a/bindings/python/setup.py
+++ b/bindings/python/setup.py
@@ -3,7 +3,7 @@
 
 setup(
     name="tokenizers",
-    version="0.1.0",
+    version="0.1.1",
     description="Fast and Customizable Tokenizers",
     long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
diff --git a/bindings/python/tokenizers/__init__.py b/bindings/python/tokenizers/__init__.py
index b1f509a9c..a1167f485 100644
--- a/bindings/python/tokenizers/__init__.py
+++ b/bindings/python/tokenizers/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.1.0"
+__version__ = "0.1.1"
 
 from .tokenizers import Tokenizer, Encoding
 from .tokenizers import decoders
diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml
index 00c5d0161..cca2a70fc 100644
--- a/tokenizers/Cargo.toml
+++ b/tokenizers/Cargo.toml
@@ -2,7 +2,7 @@
 authors = ["Anthony MOI <m.anthony.moi@gmail.com>"]
 edition = "2018"
 name = "tokenizers"
-version = "0.6.0"
+version = "0.6.1"
 homepage = "https://github.com/huggingface/tokenizers"
 repository = "https://github.com/huggingface/tokenizers"
 documentation = "https://docs.rs/tokenizers/"
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 77bbac0c6..a6d2f7927 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -565,6 +565,8 @@ impl Tokenizer {
             }
         }
 
+        self.refresh_added_tokens();
+
         added
     }
 
@@ -591,11 +593,27 @@ impl Tokenizer {
                 .or_insert_with(|| token.clone());
         }
 
+        self.refresh_added_tokens();
+
+        // Return the number of added tokens
+        tokens.len() - ignored
+    }
+
+    fn refresh_added_tokens(&mut self) {
         // We rebuild the regex here everytime on purpose, because the added tokens may
         // have changed
+        let special_tokens = self
+            .special_tokens
+            .keys()
+            .map(|t| AddedToken {
+                content: t.to_owned(),
+                single_word: true,
+            })
+            .collect::<Vec<_>>();
         let added_tokens = self
             .added_tokens
             .keys()
+            .chain(special_tokens.iter())
             .map(|token| {
                 if token.single_word {
                     let first_b = token
@@ -635,9 +653,6 @@ impl Tokenizer {
             self.split_re =
                 Some(regex::Regex::new(&format!(r"({})", added_tokens.join("|"))).unwrap());
         }
-
-        // Return the number of added tokens
-        tokens.len() - ignored
     }
 
     /// Split the given sentence on multiple parts, finding the added tokens and their id in the process
@@ -677,10 +692,13 @@
             .into_iter()
             .map(|(start, end)| unsafe {
                 let s = sentence.get_unchecked(start..end).to_owned();
-                let id = self.added_tokens.get(&AddedToken {
-                    content: s.clone(),
-                    ..Default::default()
-                });
+                let mut id = self.special_tokens.get(&s);
+                if id.is_none() {
+                    id = self.added_tokens.get(&AddedToken {
+                        content: s.clone(),
+                        ..Default::default()
+                    });
+                }
                 (s, id.copied())
             })
             .collect()
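
Note on the mod.rs change: before this patch, `split_re` was built from `self.added_tokens` alone, and it was only rebuilt inside the token-adding function shown in the second hunk. A token registered only as a special token therefore never made it into the regex, was not split out of the input, and ended up sub-tokenized by the underlying model. The patch factors the rebuild into `refresh_added_tokens()`, calls it from both adding paths, chains the special tokens into the pattern, and makes the id lookup consult `special_tokens` before `added_tokens`.

Below is a minimal, self-contained sketch of that rebuild logic, assuming the `regex` crate. It is not the crate's actual code: the `Toy` struct and its plain `Vec<String>` fields are hypothetical stand-ins for the real `Tokenizer` fields, and the `single_word` boundary handling from the real `AddedToken` path is omitted.

use regex::Regex;

struct Toy {
    added_tokens: Vec<String>,
    special_tokens: Vec<String>,
    split_re: Option<Regex>,
}

impl Toy {
    // Mirrors the patched `refresh_added_tokens`: the pattern is built from
    // the added tokens *chained with* the special tokens, so tokens that were
    // only registered as special are matched by the split regex too.
    fn refresh_added_tokens(&mut self) {
        let patterns = self
            .added_tokens
            .iter()
            .chain(self.special_tokens.iter())
            .map(|t| regex::escape(t))
            .collect::<Vec<_>>();
        self.split_re = if patterns.is_empty() {
            None
        } else {
            Some(Regex::new(&format!(r"({})", patterns.join("|"))).unwrap())
        };
    }
}

fn main() {
    let mut toy = Toy {
        added_tokens: vec![],
        special_tokens: vec!["[CLS]".into(), "[SEP]".into()],
        split_re: None,
    };
    toy.refresh_added_tokens();

    // With the pre-patch behavior (no `.chain(...)` over the special tokens),
    // `split_re` would be `None` here and "[CLS]" would be left for the model
    // to sub-tokenize instead of being split out as one unit.
    let re = toy.split_re.as_ref().unwrap();
    assert_eq!(re.find("[CLS] hello [SEP]").unwrap().as_str(), "[CLS]");
}

Checking `special_tokens` before `added_tokens` in the sketch's real counterpart (the last hunk) means a string registered in both maps resolves to its special-token id first, which matches the intent of the commit title.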