diff --git a/ernie/ernie.py b/ernie/ernie.py index bac03b0..a61c01d 100644 --- a/ernie/ernie.py +++ b/ernie/ernie.py @@ -73,9 +73,11 @@ def tokenizer(self): def load_dataset(self, dataframe=None, validation_split=0.1, + random_state=None, stratify=None, csv_path=None, read_csv_kwargs=None): + if dataframe is None and csv_path is None: raise ValueError @@ -95,11 +97,13 @@ def load_dataset(self, labels, test_size=validation_split, shuffle=True, + random_state=random_state, stratify=stratify ) self._training_features = get_features( self._tokenizer, training_sentences, training_labels) + self._training_size = len(training_sentences) self._validation_features = get_features(