Skip to content

Commit c0d3697

Browse files
chore(trainer): add and improve trainer signature (#1838)
* chore(trainers): add __init__ to fix python type check errors
* restore
* chore(trainer): add and improve trainer signature
* clean fix
* chore(fmt): fix cargo fmt error

---------

Co-authored-by: Arthur <[email protected]>
1 parent c91d76a commit c0d3697

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

bindings/python/py_src/tokenizers/trainers/__init__.pyi

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,20 @@ class BpeTrainer(Trainer):
4545
highly repetitive tokens like `======` for wikipedia
4646
4747
"""
48+
def __init__(
49+
self,
50+
vocab_size=30000,
51+
min_frequency=0,
52+
show_progress=True,
53+
special_tokens=[],
54+
limit_alphabet=None,
55+
initial_alphabet=[],
56+
continuing_subword_prefix=None,
57+
end_of_word_suffix=None,
58+
max_token_length=None,
59+
words={},
60+
):
61+
pass
4862

4963
class UnigramTrainer(Trainer):
5064
"""
@@ -85,6 +99,7 @@ class UnigramTrainer(Trainer):
8599
vocab_size=8000,
86100
show_progress=True,
87101
special_tokens=[],
102+
initial_alphabet=[],
88103
shrinking_factor=0.75,
89104
unk_token=None,
90105
max_piece_length=16,
@@ -109,6 +124,8 @@ class WordLevelTrainer(Trainer):
109124
special_tokens (:obj:`List[Union[str, AddedToken]]`):
110125
A list of special tokens the model should know of.
111126
"""
127+
def __init__(self, vocab_size=30000, min_frequency=0, show_progress=True, special_tokens=[]):
128+
pass
112129

113130
class WordPieceTrainer(Trainer):
114131
"""

bindings/python/src/trainers.rs

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -312,7 +312,10 @@ impl PyBpeTrainer {
312312
}
313313

314314
#[new]
315-
#[pyo3(signature = (**kwargs), text_signature = None)]
315+
#[pyo3(
316+
signature = (**kwargs),
317+
text_signature = "(self, vocab_size=30000, min_frequency=0, show_progress=True, special_tokens=[], limit_alphabet=None, initial_alphabet=[], continuing_subword_prefix=None, end_of_word_suffix=None, max_token_length=None, words={})"
318+
)]
316319
pub fn new(kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<(Self, PyTrainer)> {
317320
let mut builder = tk::models::bpe::BpeTrainer::builder();
318321
if let Some(kwargs) = kwargs {
@@ -518,7 +521,7 @@ impl PyWordPieceTrainer {
518521
#[new]
519522
#[pyo3(
520523
signature = (** kwargs),
521-
text_signature = "(self, vocab_size=30000, min_frequency=0, show_progress=True, special_tokens=[], limit_alphabet=None, initial_alphabet= [],continuing_subword_prefix=\"##\", end_of_word_suffix=None)"
524+
text_signature = "(self, vocab_size=30000, min_frequency=0, show_progress=True, special_tokens=[], limit_alphabet=None, initial_alphabet=[], continuing_subword_prefix=\"##\", end_of_word_suffix=None)"
522525
)]
523526
pub fn new(kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<(Self, PyTrainer)> {
524527
let mut builder = tk::models::wordpiece::WordPieceTrainer::builder();
@@ -659,7 +662,10 @@ impl PyWordLevelTrainer {
659662
}
660663

661664
#[new]
662-
#[pyo3(signature = (**kwargs), text_signature = None)]
665+
#[pyo3(
666+
signature = (**kwargs),
667+
text_signature = "(self, vocab_size=30000, min_frequency=0, show_progress=True, special_tokens=[])"
668+
)]
663669
pub fn new(kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<(Self, PyTrainer)> {
664670
let mut builder = tk::models::wordlevel::WordLevelTrainer::builder();
665671

@@ -826,7 +832,7 @@ impl PyUnigramTrainer {
826832
#[new]
827833
#[pyo3(
828834
signature = (**kwargs),
829-
text_signature = "(self, vocab_size=8000, show_progress=True, special_tokens=[], shrinking_factor=0.75, unk_token=None, max_piece_length=16, n_sub_iterations=2)"
835+
text_signature = "(self, vocab_size=8000, show_progress=True, special_tokens=[], initial_alphabet=[], shrinking_factor=0.75, unk_token=None, max_piece_length=16, n_sub_iterations=2)"
830836
)]
831837
pub fn new(kwargs: Option<Bound<'_, PyDict>>) -> PyResult<(Self, PyTrainer)> {
832838
let mut builder = tk::models::unigram::UnigramTrainer::builder();

0 commit comments

Comments
 (0)