From fc9e81d4ab50f62d72d0db381465600105940743 Mon Sep 17 00:00:00 2001
From: Anthony MOI <m.anthony.moi@gmail.com>
Date: Sun, 12 Jan 2020 02:35:45 -0500
Subject: [PATCH] Fix split on special tokens & bump version

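Special tokens were not included in the regex used to split input on
added tokens, so they could be broken apart during tokenization. The
regex is now rebuilt by a dedicated refresh_added_tokens() helper,
called after adding special tokens as well as after adding regular
tokens, and it chains the special tokens together with the added
tokens. When splitting, the id of a matched piece is now looked up in
special_tokens first, falling back to added_tokens. Versions are bumped
accordingly: tokenizers 0.6.0 -> 0.6.1, Python bindings 0.1.0 -> 0.1.1.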
---
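Note (not part of the commit message): below is a minimal standalone
sketch of the splitting technique this patch relies on. It is an
illustration only, not the crate's API: build_split_re() and
split_on_tokens() are hypothetical helpers, the example tokens are made
up, and the only assumed dependency is the regex crate. It mirrors what
refresh_added_tokens() does (escape every token and join them into one
alternation) and how the split isolates each matched token.

    // Illustrative sketch, not part of the patch.
    use regex::Regex;

    // Build one alternation from special tokens and added tokens, like
    // refresh_added_tokens() does after every token-adding call.
    fn build_split_re(special_tokens: &[&str], added_tokens: &[&str]) -> Regex {
        let pattern = special_tokens
            .iter()
            .chain(added_tokens.iter())
            .map(|t| regex::escape(t))
            .collect::<Vec<_>>()
            .join("|");
        Regex::new(&format!(r"({})", pattern)).unwrap()
    }

    // Split a sentence so that every matched token becomes its own piece.
    fn split_on_tokens<'a>(re: &Regex, sentence: &'a str) -> Vec<&'a str> {
        let mut pieces = Vec::new();
        let mut last = 0;
        for m in re.find_iter(sentence) {
            if m.start() > last {
                pieces.push(&sentence[last..m.start()]);
            }
            pieces.push(m.as_str());
            last = m.end();
        }
        if last < sentence.len() {
            pieces.push(&sentence[last..]);
        }
        pieces
    }

    fn main() {
        let re = build_split_re(&["[MASK]"], &["[CUSTOM]"]);
        // "[MASK]" stays whole instead of being tokenized as regular text.
        println!("{:?}", split_on_tokens(&re, "Hello [MASK] and [CUSTOM]!"));
    }
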
 bindings/python/Cargo.lock             |  6 ++---
 bindings/python/Cargo.toml             |  2 +-
 bindings/python/setup.py               |  2 +-
 bindings/python/tokenizers/__init__.py |  2 +-
 tokenizers/Cargo.toml                  |  2 +-
 tokenizers/src/tokenizer/mod.rs        | 32 ++++++++++++++++++++------
 6 files changed, 32 insertions(+), 14 deletions(-)

diff --git a/bindings/python/Cargo.lock b/bindings/python/Cargo.lock
index b8c7d8f36..3c2f2e56e 100644
--- a/bindings/python/Cargo.lock
+++ b/bindings/python/Cargo.lock
@@ -555,7 +555,7 @@ dependencies = [
 
 [[package]]
 name = "tokenizers"
-version = "0.6.0"
+version = "0.6.1"
 dependencies = [
  "clap 2.33.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "indicatif 0.13.0 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -571,10 +571,10 @@ dependencies = [
 
 [[package]]
 name = "tokenizers-python"
-version = "0.1.0"
+version = "0.1.1"
 dependencies = [
  "pyo3 0.8.4 (registry+https://github.com/rust-lang/crates.io-index)",
- "tokenizers 0.6.0",
+ "tokenizers 0.6.1",
 ]
 
 [[package]]
diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml
index fbb5785bf..03e583bfd 100644
--- a/bindings/python/Cargo.toml
+++ b/bindings/python/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "tokenizers-python"
-version = "0.1.0"
+version = "0.1.1"
 authors = ["Anthony MOI <m.anthony.moi@gmail.com>"]
 edition = "2018"
 
diff --git a/bindings/python/setup.py b/bindings/python/setup.py
index 1cc963441..b31427486 100644
--- a/bindings/python/setup.py
+++ b/bindings/python/setup.py
@@ -3,7 +3,7 @@
 
 setup(
     name="tokenizers",
-    version="0.1.0",
+    version="0.1.1",
     description="Fast and Customizable Tokenizers",
     long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
diff --git a/bindings/python/tokenizers/__init__.py b/bindings/python/tokenizers/__init__.py
index b1f509a9c..a1167f485 100644
--- a/bindings/python/tokenizers/__init__.py
+++ b/bindings/python/tokenizers/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.1.0"
+__version__ = "0.1.1"
 
 from .tokenizers import Tokenizer, Encoding
 from .tokenizers import decoders
diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml
index 00c5d0161..cca2a70fc 100644
--- a/tokenizers/Cargo.toml
+++ b/tokenizers/Cargo.toml
@@ -2,7 +2,7 @@
 authors = ["Anthony MOI <m.anthony.moi@gmail.com>"]
 edition = "2018"
 name = "tokenizers"
-version = "0.6.0"
+version = "0.6.1"
 homepage = "https://github.com/huggingface/tokenizers"
 repository = "https://github.com/huggingface/tokenizers"
 documentation = "https://docs.rs/tokenizers/"
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 77bbac0c6..a6d2f7927 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -565,6 +565,8 @@ impl Tokenizer {
             }
         }
 
+        self.refresh_added_tokens();
+
         added
     }
 
@@ -591,11 +593,27 @@ impl Tokenizer {
                 .or_insert_with(|| token.clone());
         }
 
+        self.refresh_added_tokens();
+
+        // Return the number of added tokens
+        tokens.len() - ignored
+    }
+
+    fn refresh_added_tokens(&mut self) {
         // We rebuild the regex here everytime on purpose, because the added tokens may
         // have changed
+        let special_tokens = self
+            .special_tokens
+            .keys()
+            .map(|t| AddedToken {
+                content: t.to_owned(),
+                single_word: true,
+            })
+            .collect::<Vec<_>>();
         let added_tokens = self
             .added_tokens
             .keys()
+            .chain(special_tokens.iter())
             .map(|token| {
                 if token.single_word {
                     let first_b = token
@@ -635,9 +653,6 @@ impl Tokenizer {
             self.split_re =
                 Some(regex::Regex::new(&format!(r"({})", added_tokens.join("|"))).unwrap());
         }
-
-        // Return the number of added tokens
-        tokens.len() - ignored
     }
 
     /// Split the given sentence on multiple parts, finding the added tokens and their id in the process
@@ -677,10 +692,13 @@ impl Tokenizer {
                     .into_iter()
                     .map(|(start, end)| unsafe {
                         let s = sentence.get_unchecked(start..end).to_owned();
-                        let id = self.added_tokens.get(&AddedToken {
-                            content: s.clone(),
-                            ..Default::default()
-                        });
+                        let mut id = self.special_tokens.get(&s);
+                        if id.is_none() {
+                            id = self.added_tokens.get(&AddedToken {
+                                content: s.clone(),
+                                ..Default::default()
+                            });
+                        }
                         (s, id.copied())
                     })
                     .collect()