122 changes: 74 additions & 48 deletions bindings/python/src/encoding.rs
@@ -1,68 +1,69 @@
use std::sync::Mutex;

use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use pyo3::IntoPyObjectExt;
use tk::tokenizer::{Offsets, PaddingDirection};
use tk::utils::truncation::TruncationDirection;
use tokenizers as tk;

use crate::error::{deprecation_warning, PyError};

/// The :class:`~tokenizers.Encoding` represents the output of a :class:`~tokenizers.Tokenizer`.
#[pyclass(dict, module = "tokenizers", name = "Encoding")]
#[pyclass(dict, module = "tokenizers", name = "Encoding", frozen)]
#[repr(transparent)]
pub struct PyEncoding {
pub encoding: tk::tokenizer::Encoding,
pub encoding: Mutex<tk::tokenizer::Encoding>,
}

impl From<tk::tokenizer::Encoding> for PyEncoding {
fn from(v: tk::tokenizer::Encoding) -> Self {
Self { encoding: v }
Self {
encoding: Mutex::new(v),
}
}
}

#[pymethods]
impl PyEncoding {
#[new]
#[pyo3(text_signature = None)]
fn new() -> Self {
Self {
encoding: tk::tokenizer::Encoding::default(),
}
fn new(encoding: Option<PyObject>, py: Python) -> PyResult<Self> {
Ok(Self {
encoding: Mutex::new(match encoding {
Some(encoding) => {
let s = encoding.extract::<&[u8]>(py)?;
serde_json::from_slice(s).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to interpret bytes as Encoding: {e}"
))
})?
}
None => tk::tokenizer::Encoding::default(),
}),
})
}

fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let data = serde_json::to_string(&self.encoding).map_err(|e| {
fn __getnewargs__(&self, py: Python) -> PyResult<PyObject> {
let data = serde_json::to_string(&(*self.encoding.lock().unwrap())).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to pickle Encoding: {e}"
))
})?;
Ok(PyBytes::new(py, data.as_bytes()).into())
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&[u8]>(py) {
Ok(s) => {
self.encoding = serde_json::from_slice(s).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to unpickle Encoding: {e}"
))
})?;
Ok(())
}
Err(e) => Err(e),
}
PyTuple::new(py, [PyBytes::new(py, data.as_bytes())])?.into_py_any(py)
}

fn __repr__(&self) -> PyResult<String> {
Ok(format!(
"Encoding(num_tokens={}, attributes=[ids, type_ids, tokens, offsets, \
attention_mask, special_tokens_mask, overflowing])",
self.encoding.get_ids().len()
self.encoding.lock().unwrap().get_ids().len()
))
}

fn __len__(&self) -> PyResult<usize> {
Ok(self.encoding.len())
Ok(self.encoding.lock().unwrap().len())
}

/// Merge the list of encodings into one final :class:`~tokenizers.Encoding`
@@ -81,7 +82,9 @@ impl PyEncoding {
#[pyo3(text_signature = "(encodings, growing_offsets=True)")]
fn merge(encodings: Vec<PyRef<PyEncoding>>, growing_offsets: bool) -> PyEncoding {
tk::tokenizer::Encoding::merge(
encodings.into_iter().map(|e| e.encoding.clone()),
encodings
.into_iter()
.map(|e| e.encoding.lock().unwrap().clone()),
growing_offsets,
)
.into()
@@ -93,16 +96,16 @@ impl PyEncoding {
/// :obj:`int`: The number of sequences in this :class:`~tokenizers.Encoding`
#[getter]
fn get_n_sequences(&self) -> usize {
self.encoding.n_sequences()
self.encoding.lock().unwrap().n_sequences()
}

/// Set the given sequence index
///
/// Set the given sequence index for the whole range of tokens contained in this
/// :class:`~tokenizers.Encoding`.
#[pyo3(text_signature = "(self, sequence_id)")]
fn set_sequence_id(&mut self, sequence_id: usize) {
self.encoding.set_sequence_id(sequence_id);
fn set_sequence_id(&self, sequence_id: usize) {
self.encoding.lock().unwrap().set_sequence_id(sequence_id);
}

/// The generated IDs
@@ -114,7 +117,7 @@ impl PyEncoding {
/// :obj:`List[int]`: The list of IDs
#[getter]
fn get_ids(&self) -> Vec<u32> {
self.encoding.get_ids().to_vec()
self.encoding.lock().unwrap().get_ids().to_vec()
}

/// The generated tokens
@@ -125,7 +128,7 @@ impl PyEncoding {
/// :obj:`List[str]`: The list of tokens
#[getter]
fn get_tokens(&self) -> Vec<String> {
self.encoding.get_tokens().to_vec()
self.encoding.lock().unwrap().get_tokens().to_vec()
}

/// The generated word indices.
@@ -168,7 +171,7 @@ impl PyEncoding {
/// A :obj:`List` of :obj:`Optional[int]`: A list of optional word index.
#[getter]
fn get_word_ids(&self) -> Vec<Option<u32>> {
self.encoding.get_word_ids().to_vec()
self.encoding.lock().unwrap().get_word_ids().to_vec()
}

/// The generated sequence indices.
@@ -181,7 +184,7 @@ impl PyEncoding {
/// A :obj:`List` of :obj:`Optional[int]`: A list of optional sequence index.
#[getter]
fn get_sequence_ids(&self) -> Vec<Option<usize>> {
self.encoding.get_sequence_ids()
self.encoding.lock().unwrap().get_sequence_ids()
}

/// The generated type IDs
@@ -193,7 +196,7 @@ impl PyEncoding {
/// :obj:`List[int]`: The list of type ids
#[getter]
fn get_type_ids(&self) -> Vec<u32> {
self.encoding.get_type_ids().to_vec()
self.encoding.lock().unwrap().get_type_ids().to_vec()
}

/// The offsets associated to each token
@@ -205,7 +208,7 @@ impl PyEncoding {
/// A :obj:`List` of :obj:`Tuple[int, int]`: The list of offsets
#[getter]
fn get_offsets(&self) -> Vec<(usize, usize)> {
self.encoding.get_offsets().to_vec()
self.encoding.lock().unwrap().get_offsets().to_vec()
}

/// The special token mask
@@ -216,7 +219,11 @@ impl PyEncoding {
/// :obj:`List[int]`: The special tokens mask
#[getter]
fn get_special_tokens_mask(&self) -> Vec<u32> {
self.encoding.get_special_tokens_mask().to_vec()
self.encoding
.lock()
.unwrap()
.get_special_tokens_mask()
.to_vec()
}

/// The attention mask
@@ -229,7 +236,7 @@ impl PyEncoding {
/// :obj:`List[int]`: The attention mask
#[getter]
fn get_attention_mask(&self) -> Vec<u32> {
self.encoding.get_attention_mask().to_vec()
self.encoding.lock().unwrap().get_attention_mask().to_vec()
}

/// A :obj:`List` of overflowing :class:`~tokenizers.Encoding`
@@ -244,6 +251,8 @@ impl PyEncoding {
#[getter]
fn get_overflowing(&self) -> Vec<PyEncoding> {
self.encoding
.lock()
.unwrap()
.get_overflowing()
.clone()
.into_iter()
@@ -265,7 +274,10 @@ impl PyEncoding {
#[pyo3(signature = (word_index, sequence_index = 0))]
#[pyo3(text_signature = "(self, word_index, sequence_index=0)")]
fn word_to_tokens(&self, word_index: u32, sequence_index: usize) -> Option<(usize, usize)> {
self.encoding.word_to_tokens(word_index, sequence_index)
self.encoding
.lock()
.unwrap()
.word_to_tokens(word_index, sequence_index)
}

/// Get the offsets of the word at the given index in one of the input sequences.
@@ -281,7 +293,10 @@ impl PyEncoding {
#[pyo3(signature = (word_index, sequence_index = 0))]
#[pyo3(text_signature = "(self, word_index, sequence_index=0)")]
fn word_to_chars(&self, word_index: u32, sequence_index: usize) -> Option<Offsets> {
self.encoding.word_to_chars(word_index, sequence_index)
self.encoding
.lock()
.unwrap()
.word_to_chars(word_index, sequence_index)
}

/// Get the index of the sequence represented by the given token.
@@ -297,7 +312,7 @@ impl PyEncoding {
/// :obj:`int`: The sequence id of the given token
#[pyo3(text_signature = "(self, token_index)")]
fn token_to_sequence(&self, token_index: usize) -> Option<usize> {
self.encoding.token_to_sequence(token_index)
self.encoding.lock().unwrap().token_to_sequence(token_index)
}

/// Get the offsets of the token at the given index.
@@ -314,7 +329,7 @@ impl PyEncoding {
/// :obj:`Tuple[int, int]`: The token offsets :obj:`(first, last + 1)`
#[pyo3(text_signature = "(self, token_index)")]
fn token_to_chars(&self, token_index: usize) -> Option<Offsets> {
let (_, offsets) = self.encoding.token_to_chars(token_index)?;
let (_, offsets) = self.encoding.lock().unwrap().token_to_chars(token_index)?;
Some(offsets)
}

@@ -332,7 +347,7 @@ impl PyEncoding {
/// :obj:`int`: The index of the word in the relevant input sequence.
#[pyo3(text_signature = "(self, token_index)")]
fn token_to_word(&self, token_index: usize) -> Option<u32> {
let (_, word_idx) = self.encoding.token_to_word(token_index)?;
let (_, word_idx) = self.encoding.lock().unwrap().token_to_word(token_index)?;
Some(word_idx)
}

@@ -349,7 +364,10 @@ impl PyEncoding {
#[pyo3(signature = (char_pos, sequence_index = 0))]
#[pyo3(text_signature = "(self, char_pos, sequence_index=0)")]
fn char_to_token(&self, char_pos: usize, sequence_index: usize) -> Option<usize> {
self.encoding.char_to_token(char_pos, sequence_index)
self.encoding
.lock()
.unwrap()
.char_to_token(char_pos, sequence_index)
}

/// Get the word that contains the char at the given position in the input sequence.
@@ -365,7 +383,10 @@ impl PyEncoding {
#[pyo3(signature = (char_pos, sequence_index = 0))]
#[pyo3(text_signature = "(self, char_pos, sequence_index=0)")]
fn char_to_word(&self, char_pos: usize, sequence_index: usize) -> Option<u32> {
self.encoding.char_to_word(char_pos, sequence_index)
self.encoding
.lock()
.unwrap()
.char_to_word(char_pos, sequence_index)
}

/// Pad the :class:`~tokenizers.Encoding` at the given length
@@ -389,7 +410,7 @@ impl PyEncoding {
#[pyo3(
text_signature = "(self, length, direction='right', pad_id=0, pad_type_id=0, pad_token='[PAD]')"
)]
fn pad(&mut self, length: usize, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<()> {
fn pad(&self, length: usize, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<()> {
let mut pad_id = 0;
let mut pad_type_id = 0;
let mut pad_token = "[PAD]".to_string();
@@ -419,6 +440,8 @@ impl PyEncoding {
}
}
self.encoding
.lock()
.unwrap()
.pad(length, pad_id, pad_type_id, &pad_token, direction);
Ok(())
}
@@ -439,7 +462,7 @@ impl PyEncoding {
/// Truncate direction
#[pyo3(signature = (max_length, stride = 0, direction = "right"))]
#[pyo3(text_signature = "(self, max_length, stride=0, direction='right')")]
fn truncate(&mut self, max_length: usize, stride: usize, direction: &str) -> PyResult<()> {
fn truncate(&self, max_length: usize, stride: usize, direction: &str) -> PyResult<()> {
let tdir = match direction {
"left" => Ok(TruncationDirection::Left),
"right" => Ok(TruncationDirection::Right),
Expand All @@ -449,7 +472,10 @@ impl PyEncoding {
),
}?;

self.encoding.truncate(max_length, stride, tdir);
self.encoding
.lock()
.unwrap()
.truncate(max_length, stride, tdir);
Ok(())
}
}
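The encoding.rs changes follow a common PyO3 pattern: declare the pyclass `frozen`, keep the Rust state behind a `Mutex` for interior mutability, and move pickling from `__getstate__`/`__setstate__` to `__getnewargs__` plus a constructor that accepts the serialized bytes, since a frozen object cannot be mutated after creation. The following is a minimal, self-contained sketch of that pattern, not code from this PR; the names (`Inner`, `PyInner`, `push`) and the explicit `signature` attribute are illustrative, and it assumes pyo3, serde (with the `derive` feature), and serde_json as dependencies.

```rust
use std::sync::Mutex;

use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyTuple};
use pyo3::IntoPyObjectExt;
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Default)]
struct Inner {
    ids: Vec<u32>,
}

// `frozen` removes PyO3's runtime borrow tracking, so methods only ever
// receive `&self`; all mutation goes through the Mutex instead.
#[pyclass(frozen)]
struct PyInner {
    inner: Mutex<Inner>,
}

#[pymethods]
impl PyInner {
    // The constructor doubles as the unpickling entry point: pickle calls it
    // with the bytes produced by __getnewargs__ below.
    #[new]
    #[pyo3(signature = (state=None))]
    fn new(state: Option<PyObject>, py: Python) -> PyResult<Self> {
        let inner = match state {
            Some(state) => {
                let bytes = state.extract::<&[u8]>(py)?;
                serde_json::from_slice(bytes).map_err(|e| {
                    exceptions::PyException::new_err(format!("invalid state: {e}"))
                })?
            }
            None => Inner::default(),
        };
        Ok(Self {
            inner: Mutex::new(inner),
        })
    }

    // Replaces __getstate__/__setstate__: unpickling a frozen object has to go
    // through __new__, so pickle is handed the constructor arguments instead.
    fn __getnewargs__(&self, py: Python) -> PyResult<PyObject> {
        let data = serde_json::to_string(&*self.inner.lock().unwrap())
            .map_err(|e| exceptions::PyException::new_err(e.to_string()))?;
        PyTuple::new(py, [PyBytes::new(py, data.as_bytes())])?.into_py_any(py)
    }

    // `&self` rather than `&mut self`, mirroring the switch made for
    // `pad`, `truncate`, and `set_sequence_id` above.
    fn push(&self, id: u32) {
        self.inner.lock().unwrap().ids.push(id);
    }

    fn __len__(&self) -> usize {
        self.inner.lock().unwrap().ids.len()
    }
}
```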
4 changes: 2 additions & 2 deletions bindings/python/src/processors.rs
@@ -169,8 +169,8 @@ impl PyPostProcessor {
add_special_tokens: bool,
) -> PyResult<PyEncoding> {
let final_encoding = ToPyResult(self.processor.process(
encoding.encoding.clone(),
pair.map(|e| e.encoding.clone()),
encoding.encoding.lock().unwrap().clone(),
pair.map(|e| e.encoding.lock().unwrap().clone()),
add_special_tokens,
))
.into_py()?;
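Both downstream call sites (here and in tokenizer.rs below) hold the lock only long enough to clone the inner Encoding, then run the processor on the owned copy. A short sketch of that lock-clone-release pattern, with placeholder names rather than the real API:

```rust
use std::sync::Mutex;

#[derive(Clone, Default)]
struct Encoding {
    ids: Vec<u32>,
}

fn post_process(encoding: &Mutex<Encoding>, pair: Option<&Mutex<Encoding>>) -> Encoding {
    // Each guard is a temporary dropped at the end of its statement, so no
    // lock is held while the actual processing runs on the owned clones.
    let main = encoding.lock().unwrap().clone();
    let _pair = pair.map(|p| p.lock().unwrap().clone());
    // ... processing on `main` / `_pair` would happen here ...
    main
}
```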
4 changes: 2 additions & 2 deletions bindings/python/src/tokenizer.rs
@@ -1761,8 +1761,8 @@ impl PyTokenizer {
ToPyResult(
self.tokenizer
.post_process(
encoding.encoding.clone(),
pair.map(|p| p.encoding.clone()),
encoding.encoding.lock().unwrap().clone(),
pair.map(|p| p.encoding.lock().unwrap().clone()),
add_special_tokens,
)
.map(|e| e.into()),
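The diff does not spell out the motivation, but a plausible reading is that `frozen` plus `Mutex` lets an Encoding be shared with plain Rust threads without PyO3's runtime borrow checking getting in the way. The sketch below shows what `frozen` enables; `Py::get` is a real PyO3 accessor for frozen, `Sync` pyclasses, while the surrounding names are made up for illustration:

```rust
use std::sync::Mutex;

use pyo3::prelude::*;

#[pyclass(frozen)]
struct Shared {
    value: Mutex<u64>,
}

fn bump_from_another_thread(obj: Py<Shared>) {
    std::thread::spawn(move || {
        // Py::get is only available for frozen + Sync pyclasses; mutation is
        // serialized by the Mutex, not by Python-side borrow flags.
        let shared: &Shared = obj.get();
        *shared.value.lock().unwrap() += 1;
    })
    .join()
    .unwrap();
}
```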