
Commit

Fixes.
eaplatanios committed Jan 23, 2024
1 parent 25a7f91 commit 625b080
Showing 2 changed files with 87 additions and 0 deletions.
37 changes: 37 additions & 0 deletions bindings/python/src/tokenizer.rs
@@ -25,6 +25,7 @@ use super::trainers::PyTrainer;
use crate::processors::PyPostProcessor;
use crate::utils::{MaybeSizedIterator, PyBufferedIterator};
use std::collections::BTreeMap;
use tk::{AddedToken, AddedVocabulary};

/// Represents a token that can be added to a :class:`~tokenizers.Tokenizer`.
/// It can have special options that define the way it should behave.
@@ -662,6 +663,42 @@ impl PyTokenizer {
        self.tokenizer.get_vocab(with_added_tokens)
    }

    /// Sets the underlying added tokens vocabulary
    ///
    /// Args:
    ///     added_tokens_decoder (:obj:`Dict[int, AddedToken]`):
    ///         Map from added token ID to :obj:`AddedToken`.
    ///     encode_special_tokens (:obj:`bool`, defaults to :obj:`False`):
    ///         Whether or not special tokens should be split when encoding. This is equivalent to ignoring them.
    #[pyo3(signature = (added_tokens_decoder, encode_special_tokens = false))]
    #[pyo3(text_signature = "(self, added_tokens_decoder, encode_special_tokens=False)")]
    fn set_added_tokens_decoder(
        &mut self,
        added_tokens_decoder: &PyDict,
        encode_special_tokens: bool,
    ) -> PyResult<()> {
        added_tokens_decoder
            .iter()
            .map(|(key, value)| {
                key.extract::<u32>().and_then(|key| {
                    value
                        .extract::<PyRefMut<PyAddedToken>>()
                        .map(|value| (key, value.get_token()))
                })
            })
            .collect::<Result<HashMap<u32, AddedToken>, PyErr>>()
            .map(|added_tokens| {
                self.tokenizer
                    .with_added_vocabulary(AddedVocabulary::from_indexed_added_tokens(
                        added_tokens,
                        encode_special_tokens,
                        self.tokenizer.get_model(),
                        self.tokenizer.get_normalizer(),
                    ))
            })?;
        Ok(())
    }

    /// Get the underlying vocabulary
    ///
    /// Returns:
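
For reference, here is a minimal pure-Rust sketch of what the binding above does once the Python dict has been converted. It assumes the `tokenizers` crate from this branch, where `AddedVocabulary` is re-exported at the crate root (as the `use tk::{AddedToken, AddedVocabulary}` import above relies on) and the tokenizer exposes the `with_added_vocabulary` setter used in the diff; the token contents and IDs are illustrative only.

use std::collections::HashMap;

use tokenizers::models::bpe::BPE;
use tokenizers::{AddedToken, AddedVocabulary, Tokenizer};

fn main() {
    // An empty BPE model; it is only needed so the rebuilt added vocabulary
    // can be refreshed against a model and (absent) normalizer.
    let mut tokenizer = Tokenizer::new(BPE::default());

    // The IDs and contents are illustrative; in practice they come from a
    // serialized `added_tokens_decoder` map.
    let mut added_tokens: HashMap<u32, AddedToken> = HashMap::new();
    added_tokens.insert(0, AddedToken::from("<pad>", true));
    added_tokens.insert(1, AddedToken::from("<extra_0>", false));

    let vocabulary = AddedVocabulary::from_indexed_added_tokens(
        added_tokens,
        false, // encode_special_tokens
        tokenizer.get_model(),
        tokenizer.get_normalizer(),
    );

    // `with_added_vocabulary` is the setter this branch relies on above.
    tokenizer.with_added_vocabulary(vocabulary);
}

From Python, the same path is reached through `Tokenizer.set_added_tokens_decoder(added_tokens_decoder, encode_special_tokens=False)`, as declared by the `text_signature` above.
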
50 changes: 50 additions & 0 deletions tokenizers/src/tokenizer/added_vocabulary.rs
@@ -186,6 +186,56 @@ impl AddedVocabulary {
            encode_special_tokens: false,
        }
    }

    /// Creates a new [AddedVocabulary] from a collection of [AddedToken]s that already have assigned IDs.
    /// This constructor is useful for constructing an [AddedVocabulary] from a pre-existing [AddedVocabulary]
    /// (e.g., from a serialized [AddedVocabulary]).
    pub fn from_indexed_added_tokens<N: Normalizer>(
        tokens: HashMap<u32, AddedToken>,
        encode_special_tokens: bool,
        model: &impl Model,
        normalizer: Option<&N>,
    ) -> Self {
        let mut vocabulary = AddedVocabulary::new();
        vocabulary.encode_special_tokens = encode_special_tokens;

        // Handle special tokens (if any).
        for token in tokens.values() {
            if token.special
                && !token.content.is_empty()
                && !vocabulary.special_tokens_set.contains(&token.content)
            {
                vocabulary.special_tokens.push(token.to_owned());
                vocabulary.special_tokens_set.insert(token.content.clone());
            }
        }

        for (token_id, token) in tokens {
            if token.content.is_empty()
                || vocabulary.added_tokens_map_r.values().any(|val| *val == token)
            {
                continue;
            }

            vocabulary.added_tokens_map
                .entry(token.content.clone())
                .and_modify(|old_id| *old_id = token_id)
                .or_insert_with(|| token_id);

            vocabulary.added_tokens_map_r
                .entry(token_id)
                .and_modify(|t| *t = token.clone())
                .or_insert_with(|| token.clone());

            if !vocabulary.special_tokens_set.contains(&token.content) {
                vocabulary.added_tokens.push(token.clone());
            }
        }

        vocabulary.refresh_added_tokens(model, normalizer);

        vocabulary
    }

    /// Size of the additional vocabulary
    #[allow(dead_code)] // Suppress the "method is never used" warning
    pub fn len(&self) -> usize {
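
As a sanity check on the behaviour described in the doc comment, the sketch below exercises the new constructor from outside the crate. It assumes the crate's existing `len`, `token_to_id`, and `id_to_token` accessors on `AddedVocabulary` and the crate-root re-export used in the bindings diff above; the empty `BPE::default()` model and the IDs are illustrative only.

use std::collections::HashMap;

use tokenizers::models::bpe::BPE;
use tokenizers::normalizers::NormalizerWrapper;
use tokenizers::{AddedToken, AddedVocabulary};

fn main() {
    // An empty model: added tokens are resolved before the model is consulted.
    let model = BPE::default();

    // Deliberately non-contiguous IDs to show they are preserved as given
    // rather than re-assigned.
    let mut tokens: HashMap<u32, AddedToken> = HashMap::new();
    tokens.insert(3, AddedToken::from("<pad>", true));
    tokens.insert(7, AddedToken::from("<extra_0>", false));

    let vocabulary = AddedVocabulary::from_indexed_added_tokens(
        tokens,
        false, // encode_special_tokens
        &model,
        None::<&NormalizerWrapper>,
    );

    assert_eq!(vocabulary.len(), 2);
    assert_eq!(vocabulary.token_to_id("<pad>", &model), Some(3));
    assert_eq!(vocabulary.id_to_token(7, &model), Some("<extra_0>".to_string()));
}

Preserving the supplied IDs verbatim is what lets a serialized `added_tokens_decoder` round-trip here, whereas the regular token-adding path assigns fresh IDs after the model vocabulary.
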
