Skip to content

Commit

Permalink
Added reproducibility
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcosFP97 committed Dec 23, 2021
1 parent c457f28 commit 847ea8f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ernie/ernie.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def model(self):
def tokenizer(self):
return self._tokenizer

def load_dataset(self, dataframe=None, validation_split=0.1, stratify=None, csv_path=None, read_csv_kwargs=None):
def load_dataset(self, dataframe=None, validation_split=0.1, random_state=None, stratify=None, csv_path=None, read_csv_kwargs=None):
if dataframe is None and csv_path is None:
raise ValueError

Expand All @@ -69,7 +69,7 @@ def load_dataset(self, dataframe=None, validation_split=0.1, stratify=None, csv_
labels = dataframe[dataframe.columns[1]].values

training_sentences, validation_sentences, training_labels, validation_labels = train_test_split(
sentences, labels, test_size=validation_split, shuffle=True, stratify=stratify)
sentences, labels, test_size=validation_split, shuffle=True, random_state=random_state, stratify=stratify)

self._training_features = get_features(self._tokenizer, training_sentences, training_labels)
self._training_size = len(training_sentences)
Expand Down

0 comments on commit 847ea8f

Please sign in to comment.