Skip to content

Commit

Permalink
feat: add __setitem__ impl to pre_tokenizer::PySequence
Browse files Browse the repository at this point in the history
  • Loading branch information
McPatate committed Jan 14, 2025
1 parent 25cc22a commit 5cc53bf
Showing 1 changed file with 31 additions and 7 deletions.
38 changes: 31 additions & 7 deletions bindings/python/src/pre_tokenizers.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::sync::{Arc, RwLock};

use pyo3::exceptions;
use pyo3::exceptions::PyException;
use pyo3::prelude::*;
use pyo3::types::*;
use serde::ser::SerializeStruct;
Expand Down Expand Up @@ -48,12 +49,16 @@ impl PyPreTokenizer {

pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<PyObject> {
let base = self.clone();
Ok(match &self.pretok {
Ok(match self.pretok {
PyPreTokenizerTypeWrapper::Sequence(_) => {
Py::new(py, (PySequence {}, base))?.into_py(py)
}
PyPreTokenizerTypeWrapper::Single(ref inner) => {
match &*inner.as_ref().read().unwrap() {
match &*inner
.as_ref()
.read()
.map_err(|_| PyException::new_err("pre tokenizer rwlock is poisoned"))?
{
PyPreTokenizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
PyPreTokenizerWrapper::Wrapped(inner) => match inner {
PreTokenizerWrapper::Whitespace(_) => {
Expand Down Expand Up @@ -467,20 +472,39 @@ impl PySequence {
/// Return the pre-tokenizer at `index`, converted to its concrete Python
/// subtype via `get_as_subtype`.
///
/// When the wrapped value is a `Sequence`, the element at `index` is looked up
/// and an `IndexError` ("Index not found") is returned if there is none. When
/// the wrapped value is a `Single`, `index` is not consulted and that single
/// pre-tokenizer is returned.
///
/// The returned object is built from a clone of the stored handle rather than
/// from a copy of the pre-tokenizer itself.
// NOTE(review): this listing is a commit diff rendered without +/- markers.
// The `Arc::clone(..)` lines appear to be the pre-commit form of the
// `.clone()` lines that directly follow them — confirm against the real file.
fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
match &self_.as_ref().pretok {
// `get` is a checked lookup: a missing index yields `None`, handled below.
PyPreTokenizerTypeWrapper::Sequence(inner) => match inner.get(index) {
Some(item) => {
PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(item)))
.get_as_subtype(py)
}
Some(item) => PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(item.clone()))
.get_as_subtype(py),
// No element at `index`: surface it to Python as an IndexError.
_ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
"Index not found",
)),
},
// Not a sequence: `index` is ignored and the single item is returned.
PyPreTokenizerTypeWrapper::Single(inner) => {
PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(inner)))
PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(inner.clone()))
.get_as_subtype(py)
}
}
}

fn __setitem__(self_: PyRef<'_, Self>, index: usize, value: Bound<'_, PyAny>) -> PyResult<()> {
let norm: PyPreTokenizer = value.extract()?;
let PyPreTokenizerTypeWrapper::Single(norm) = norm.pretok else { return Err(PyException::new_err("normalizer should not be a sequence")); };
match &self_.as_ref().pretok {
PyPreTokenizerTypeWrapper::Sequence(inner) => match inner.get(index) {
Some(item) => {
*item.write().unwrap() = norm.read().unwrap().clone();
}
_ => {
return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
"Index not found",
))
}
},
PyPreTokenizerTypeWrapper::Single(_) => {
return Err(PyException::new_err("normalizer is not a sequence"))
}
};
Ok(())
}
}

pub(crate) fn from_string(string: String) -> Result<PrependScheme, PyErr> {
Expand Down

0 comments on commit 5cc53bf

Please sign in to comment.