diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs
index 4ae7037da..b4cefe692 100644
--- a/bindings/python/src/pre_tokenizers.rs
+++ b/bindings/python/src/pre_tokenizers.rs
@@ -1,6 +1,7 @@
 use std::sync::{Arc, RwLock};
 
 use pyo3::exceptions;
+use pyo3::exceptions::PyException;
 use pyo3::prelude::*;
 use pyo3::types::*;
 use serde::ser::SerializeStruct;
@@ -48,12 +49,16 @@ impl PyPreTokenizer {
 
     pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<PyObject> {
         let base = self.clone();
-        Ok(match &self.pretok {
+        Ok(match self.pretok {
             PyPreTokenizerTypeWrapper::Sequence(_) => {
                 Py::new(py, (PySequence {}, base))?.into_py(py)
             }
             PyPreTokenizerTypeWrapper::Single(ref inner) => {
-                match &*inner.as_ref().read().unwrap() {
+                match &*inner
+                    .as_ref()
+                    .read()
+                    .map_err(|_| PyException::new_err("pre tokenizer rwlock is poisoned"))?
+                {
                     PyPreTokenizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
                     PyPreTokenizerWrapper::Wrapped(inner) => match inner {
                         PreTokenizerWrapper::Whitespace(_) => {
@@ -467,20 +472,51 @@ impl PySequence {
     fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
         match &self_.as_ref().pretok {
             PyPreTokenizerTypeWrapper::Sequence(inner) => match inner.get(index) {
-                Some(item) => {
-                    PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(item)))
-                        .get_as_subtype(py)
-                }
+                Some(item) => PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(item.clone()))
+                    .get_as_subtype(py),
                 _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
                     "Index not found",
                 )),
             },
             PyPreTokenizerTypeWrapper::Single(inner) => {
-                PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(inner)))
+                PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(inner.clone()))
                     .get_as_subtype(py)
             }
         }
     }
+
+    /// Overwrites the pre-tokenizer at `index` of this sequence with a copy of `value`.
+    ///
+    /// `value` must be a single (non-sequence) pre-tokenizer. The element is updated in
+    /// place through its lock rather than by swapping the slot itself.
+    ///
+    /// # Errors
+    ///
+    /// Fails if `value` is not a pre-tokenizer or is itself a sequence, if `self_` is not
+    /// backed by a sequence, if `index` is out of range, or if a lock is poisoned.
+    fn __setitem__(self_: PyRef<'_, Self>, index: usize, value: Bound<'_, PyAny>) -> PyResult<()> {
+        let replacement: PyPreTokenizer = value.extract()?;
+        let PyPreTokenizerTypeWrapper::Single(replacement) = replacement.pretok else {
+            return Err(PyException::new_err("pre tokenizer should not be a sequence"));
+        };
+        let PyPreTokenizerTypeWrapper::Sequence(items) = &self_.as_ref().pretok else {
+            return Err(PyException::new_err("pre tokenizer is not a sequence"));
+        };
+        let slot = items
+            .get(index)
+            .ok_or_else(|| PyErr::new::<pyo3::exceptions::PyIndexError, _>("Index not found"))?;
+        // Copy the new value out in its own statement so the read guard is released before
+        // the write lock is taken: `seq[i] = seq[i]` hands back the very same lock, and
+        // holding both guards at once would deadlock.
+        let snapshot = replacement
+            .read()
+            .map_err(|_| PyException::new_err("pre tokenizer rwlock is poisoned"))?
+            .clone();
+        *slot
+            .write()
+            .map_err(|_| PyException::new_err("pre tokenizer rwlock is poisoned"))? = snapshot;
+        Ok(())
+    }
 }
 
 pub(crate) fn from_string(string: String) -> Result<PrependScheme, PyErr> {