From c4aca41fcea88b07f37fd06dc3a334e5eea5c813 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Mart=C3=ADnez?= Date: Wed, 9 Sep 2020 13:25:35 +0200 Subject: [PATCH 1/8] Update version --- ernie/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ernie/__init__.py b/ernie/__init__.py index 7836894..fa2b07c 100644 --- a/ernie/__init__.py +++ b/ernie/__init__.py @@ -5,7 +5,7 @@ from tensorflow.python.client import device_lib import logging -__version__ = '0.0.28b0' +__version__ = '0.0.32b0' logging.getLogger().setLevel(logging.WARNING) logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR) From d99e2a5c377fbaea0804e9446c71bd3a8fa5be72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Mart=C3=ADnez?= Date: Wed, 9 Sep 2020 13:58:31 +0200 Subject: [PATCH 2/8] Config object was not instantiated when loading a local model --- ernie/__init__.py | 2 +- ernie/ernie.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/ernie/__init__.py b/ernie/__init__.py index fa2b07c..6070a57 100644 --- a/ernie/__init__.py +++ b/ernie/__init__.py @@ -5,7 +5,7 @@ from tensorflow.python.client import device_lib import logging -__version__ = '0.0.32b0' +__version__ = '0.0.33b0' logging.getLogger().setLevel(logging.WARNING) logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR) diff --git a/ernie/ernie.py b/ernie/ernie.py index afda304..9f40cda 100644 --- a/ernie/ernie.py +++ b/ernie/ernie.py @@ -58,13 +58,18 @@ def model(self): def tokenizer(self): return self._tokenizer - def load_dataset(self, dataframe=None, validation_split=0.1, stratify=None, csv_path=None, read_csv_kwargs=None): + def load_dataset(self, + dataframe=None, + validation_split=0.1, + stratify=None, + csv_path=None, + read_csv_kwargs=None): if dataframe is None and csv_path is None: raise ValueError if csv_path is not None: dataframe = pd.read_csv(csv_path, **read_csv_kwargs) - + sentences = list(dataframe[dataframe.columns[0]]) labels = dataframe[dataframe.columns[1]].values @@ -236,9 +241,11 @@ def _reload_model(self): def _load_local_model(self, model_path): try: self._tokenizer = AutoTokenizer.from_pretrained(model_path + '/tokenizer') + self._config = AutoConfig.from_pretrained(model_path + '/tokenizer') # Old models didn't use to have a tokenizer folder except OSError: self._tokenizer = AutoTokenizer.from_pretrained(model_path) + self._config = AutoConfig.from_pretrained(model_path) self._model = TFAutoModelForSequenceClassification.from_pretrained(model_path, from_pt=False) @@ -252,10 +259,6 @@ def _load_remote_model(self, model_name, tokenizer_kwargs, model_kwargs): do_lower_case = True tokenizer_kwargs.update({'do_lower_case': do_lower_case}) - self._tokenizer = None - self._model = None - self._config = None - self._tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs) self._config = AutoConfig.from_pretrained(model_name) From f598d5ca92424b7ca9a4faf646f485a2a4f5c9a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Mart=C3=ADnez?= Date: Tue, 10 Nov 2020 18:50:54 +0100 Subject: [PATCH 3/8] Fix Apache License, close #19 --- setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 019b037..e4d5a23 100755 --- a/setup.py +++ b/setup.py @@ -12,14 +12,14 @@ url='https://github.com/brunneis/ernie', author='Rodrigo Martínez Castaño', author_email='rodrigo@martinez.gal', - license='GNU General Public License v3 (GPLv3)', + license='Apache License (Version 2.0)', 
packages=find_packages(), zip_safe=False, classifiers=[ "Development Status :: 4 - Beta", "Environment :: Console", "Intended Audience :: Developers", - "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", + "License :: OSI Approved :: Apache Software License", "Operating System :: POSIX :: Linux", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: Implementation :: PyPy", @@ -30,4 +30,4 @@ 'scikit-learn>=0.22.1', 'pandas>=0.25.3', 'tensorflow>=2.1.0,!=2.2.0-rc0,!=2.2.0rc1', - 'py-cpuinfo==5.0.0']) \ No newline at end of file + 'py-cpuinfo==5.0.0']) From 96a153644c15888d0681a6b484028c6ada71319c Mon Sep 17 00:00:00 2001 From: brunneis Date: Wed, 22 Dec 2021 20:44:08 +0100 Subject: [PATCH 4/8] update dependencies --- requirements.txt | 10 +++++----- setup.py | 17 +++++++++++------ 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/requirements.txt b/requirements.txt index a272669..efa724d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -transformers==2.4.1 -scikit-learn>=0.22.1 -pandas>=0.25.3 -tensorflow>=2.1.0,!=2.2.0-rc0,!=2.2.0rc1 -py-cpuinfo==5.0.0 \ No newline at end of file +transformers>=2.4.1, < 2.5.0 +scikit-learn>=0.22.1, < 1.0.0 +pandas>=0.25.3, < 1.0.0 +tensorflow>=2.5.1, < 2.6.0 +py-cpuinfo>=5.0.0, < 6.0.0 diff --git a/setup.py b/setup.py index e4d5a23..e20f825 100755 --- a/setup.py +++ b/setup.py @@ -8,7 +8,10 @@ setup( name='ernie', version=ernie.__version__, - description='An Accessible Python Library for State-of-the-art Natural Language Processing. Built with HuggingFace\'s Transformers.', + description=( + 'An Accessible Python Library for State-of-the-art ' + 'Natural Language Processing. Built with HuggingFace\'s Transformers.' + ), url='https://github.com/brunneis/ernie', author='Rodrigo Martínez Castaño', author_email='rodrigo@martinez.gal', @@ -26,8 +29,10 @@ "Topic :: Software Development :: Libraries :: Python Modules", ], python_requires=">=3.6", - install_requires=['transformers==2.4.1', - 'scikit-learn>=0.22.1', - 'pandas>=0.25.3', - 'tensorflow>=2.1.0,!=2.2.0-rc0,!=2.2.0rc1', - 'py-cpuinfo==5.0.0']) + install_requires=[ + 'transformers>=2.4.1, < 2.5.0', + 'scikit-learn>=0.22.1, < 1.0.0', + 'pandas>=0.25.3, < 1.0.0', + 'tensorflow>=2.5.1, < 2.6.0', + 'py-cpuinfo>=5.0.0, < 6.0.0' + ]) From c6552c9ec2d238b5a09f0526fb08227e6b927a13 Mon Sep 17 00:00:00 2001 From: brunneis Date: Wed, 22 Dec 2021 21:12:35 +0100 Subject: [PATCH 5/8] format code --- ernie/__init__.py | 17 ++- ernie/aggregation_strategies.py | 57 ++++++---- ernie/ernie.py | 177 ++++++++++++++++++++++---------- ernie/helper.py | 21 ++-- ernie/models.py | 22 ++-- ernie/split_strategies.py | 48 +++++---- examples/binary_classifier.py | 17 ++- test/__init__.py | 2 +- test/dump_load.py | 8 +- test/load_csv.py | 24 ++++- test/load_model.py | 21 ++-- test/predict.py | 13 ++- test/split_aggregate.py | 158 ++++++++++++++++------------ 13 files changed, 389 insertions(+), 196 deletions(-) diff --git a/ernie/__init__.py b/ernie/__init__.py index 6070a57..0d62b70 100644 --- a/ernie/__init__.py +++ b/ernie/__init__.py @@ -1,15 +1,18 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from .ernie import * +from .ernie import * # noqa: F401, F403 from tensorflow.python.client import device_lib import logging -__version__ = '0.0.33b0' +__version__ = '1.0.0' logging.getLogger().setLevel(logging.WARNING) logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR) -logging.basicConfig(format='%(asctime)-15s [%(levelname)s] 
%(message)s', datefmt='%Y-%m-%d %H:%M:%S') +logging.basicConfig( + format='%(asctime)-15s [%(levelname)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' +) def _get_cpu_name(): @@ -20,7 +23,13 @@ def _get_cpu_name(): def _get_gpu_name(): - gpu_name = device_lib.list_local_devices()[3].physical_device_desc.split(',')[1].split('name:')[1].strip() + gpu_name = \ + device_lib\ + .list_local_devices()[3]\ + .physical_device_desc\ + .split(',')[1]\ + .split('name:')[1]\ + .strip() return gpu_name diff --git a/ernie/aggregation_strategies.py b/ernie/aggregation_strategies.py index 33e8d24..d6bb578 100644 --- a/ernie/aggregation_strategies.py +++ b/ernie/aggregation_strategies.py @@ -5,7 +5,13 @@ class AggregationStrategy: - def __init__(self, method, max_items=None, top_items=True, sorting_class_index=1): + def __init__( + self, + method, + max_items=None, + top_items=True, + sorting_class_index=1 + ): self.method = method self.max_items = max_items self.top_items = top_items @@ -20,32 +26,45 @@ def aggregate(self, softmax_tuples): softmax_dicts.append(softmax_dict) if self.max_items is not None: - softmax_dicts = sorted(softmax_dicts, key=lambda x: x[self.sorting_class_index], reverse=self.top_items) + softmax_dicts = sorted( + softmax_dicts, + key=lambda x: x[self.sorting_class_index], + reverse=self.top_items + ) if self.max_items < len(softmax_dicts): softmax_dicts = softmax_dicts[:self.max_items] softmax_list = [] for key in softmax_dicts[0].keys(): - softmax_list.append(self.method([probabilities[key] for probabilities in softmax_dicts])) + softmax_list.append(self.method( + [probabilities[key] for probabilities in softmax_dicts])) softmax_tuple = tuple(softmax_list) return softmax_tuple class AggregationStrategies: Mean = AggregationStrategy(method=mean) - MeanTopFiveBinaryClassification = AggregationStrategy(method=mean, - max_items=5, - top_items=True, - sorting_class_index=1) - MeanTopTenBinaryClassification = AggregationStrategy(method=mean, - max_items=10, - top_items=True, - sorting_class_index=1) - MeanTopFifteenBinaryClassification = AggregationStrategy(method=mean, - max_items=15, - top_items=True, - sorting_class_index=1) - MeanTopTwentyBinaryClassification = AggregationStrategy(method=mean, - max_items=20, - top_items=True, - sorting_class_index=1) \ No newline at end of file + MeanTopFiveBinaryClassification = AggregationStrategy( + method=mean, + max_items=5, + top_items=True, + sorting_class_index=1 + ) + MeanTopTenBinaryClassification = AggregationStrategy( + method=mean, + max_items=10, + top_items=True, + sorting_class_index=1 + ) + MeanTopFifteenBinaryClassification = AggregationStrategy( + method=mean, + max_items=15, + top_items=True, + sorting_class_index=1 + ) + MeanTopTwentyBinaryClassification = AggregationStrategy( + method=mean, + max_items=20, + top_items=True, + sorting_class_index=1 + ) diff --git a/ernie/ernie.py b/ernie/ernie.py index 9f40cda..bac03b0 100644 --- a/ernie/ernie.py +++ b/ernie/ernie.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -import tensorflow as tf import numpy as np import pandas as pd from transformers import ( @@ -14,10 +13,23 @@ from sklearn.model_selection import train_test_split import logging import time -from .models import Models, ModelsByFamily -from .split_strategies import SplitStrategy, SplitStrategies, RegexExpressions -from .aggregation_strategies import AggregationStrategy, AggregationStrategies -from .helper import (get_features, softmax, remove_dir, make_dir, copy_dir) +from .models import Models, 
ModelsByFamily # noqa: F401 +from .split_strategies import ( # noqa: F401 + SplitStrategy, + SplitStrategies, + RegexExpressions +) +from .aggregation_strategies import ( # noqa: F401 + AggregationStrategy, + AggregationStrategies +) +from .helper import ( + get_features, + softmax, + remove_dir, + make_dir, + copy_dir +) AUTOSAVE_PATH = './ernie-autosave/' @@ -73,14 +85,28 @@ def load_dataset(self, sentences = list(dataframe[dataframe.columns[0]]) labels = dataframe[dataframe.columns[1]].values - training_sentences, validation_sentences, training_labels, validation_labels = train_test_split( - sentences, labels, test_size=validation_split, shuffle=True, stratify=stratify) - - self._training_features = get_features(self._tokenizer, training_sentences, training_labels) + ( + training_sentences, + validation_sentences, + training_labels, + validation_labels + ) = train_test_split( + sentences, + labels, + test_size=validation_split, + shuffle=True, + stratify=stratify + ) + + self._training_features = get_features( + self._tokenizer, training_sentences, training_labels) self._training_size = len(training_sentences) - self._validation_features = get_features(self._tokenizer, validation_sentences, - validation_labels) + self._validation_features = get_features( + self._tokenizer, + validation_sentences, + validation_labels + ) self._validation_split = len(validation_sentences) logging.info(f'training_size: {self._training_size}') @@ -125,7 +151,8 @@ def fine_tune(self, training_features = self._training_features.shuffle( self._training_size).batch(training_batch_size).repeat(-1) - validation_features = self._validation_features.batch(validation_batch_size) + validation_features = self._validation_features.batch( + validation_batch_size) training_steps = self._training_size // training_batch_size if training_steps == 0: @@ -137,8 +164,6 @@ def fine_tune(self, validation_steps = self._validation_split logging.info(f'validation_steps: {validation_steps}') - temporary_path = self._get_temporary_path(name=self._model.name) - for i in range(epochs): self._model.fit(training_features, epochs=1, @@ -147,18 +172,29 @@ def fine_tune(self, validation_steps=validation_steps, **kwargs) - # The fine-tuned model does not have the same input interface after being - # exported and loaded again. + # The fine-tuned model does not have the same input interface + # after being exported and loaded again. 
self._reload_model() - def predict_one(self, text, split_strategy=None, aggregation_strategy=None): + def predict_one( + self, + text, + split_strategy=None, + aggregation_strategy=None + ): return next( self.predict([text], batch_size=1, split_strategy=split_strategy, aggregation_strategy=aggregation_strategy)) - def predict(self, texts, batch_size=32, split_strategy=None, aggregation_strategy=None): + def predict( + self, + texts, + batch_size=32, + split_strategy=None, + aggregation_strategy=None + ): if split_strategy is None: yield from self._predict_batch(texts, batch_size) @@ -178,7 +214,9 @@ def predict(self, texts, batch_size=32, split_strategy=None, aggregation_strateg predictions = list(self._predict_batch(sentences, batch_size)) for i, split_index in enumerate(split_indexes[:-1]): stop_index = split_indexes[i + 1] - yield aggregation_strategy.aggregate(predictions[split_index:stop_index]) + yield aggregation_strategy.aggregate( + predictions[split_index:stop_index] + ) def dump(self, path): if self._model_path: @@ -203,16 +241,24 @@ def _predict_batch(self, sentences: list, batch_size: int): attention_mask_list = [] stop_index = i + batch_size - stop_index = stop_index if stop_index < sentences_number else sentences_number + stop_index = stop_index if stop_index < sentences_number \ + else sentences_number + for j in range(i, stop_index): - features = self._tokenizer.encode_plus(sentences[j], - add_special_tokens=True, - max_length=self._tokenizer.max_len) - input_ids, _, attention_mask = features['input_ids'], features[ - 'token_type_ids'], features['attention_mask'] + features = self._tokenizer.encode_plus( + sentences[j], + add_special_tokens=True, + max_length=self._tokenizer.max_len + ) + input_ids, _, attention_mask = ( + features['input_ids'], + features['token_type_ids'], + features['attention_mask'] + ) input_ids = self._list_to_padded_array(features['input_ids']) - attention_mask = self._list_to_padded_array(features['attention_mask']) + attention_mask = self._list_to_padded_array( + features['attention_mask']) input_ids_list.append(input_ids) attention_mask_list.append(attention_mask) @@ -222,7 +268,10 @@ def _predict_batch(self, sentences: list, batch_size: int): 'attention_mask': np.array(attention_mask_list) } logit_predictions = self._model.predict_on_batch(input_dict) - yield from ([softmax(logit_prediction) for logit_prediction in logit_predictions[0]]) + yield from ( + [softmax(logit_prediction) + for logit_prediction in logit_predictions[0]] + ) def _list_to_padded_array(self, items): array = np.array(items) @@ -234,20 +283,26 @@ def _get_temporary_path(self, name=''): return f'{AUTOSAVE_PATH}{name}/{int(round(time.time() * 1000))}' def _reload_model(self): - self._model_path = self._get_temporary_path(name=self._get_model_family()) + self._model_path = self._get_temporary_path( + name=self._get_model_family()) self._dump(self._model_path) self._load_local_model(self._model_path) def _load_local_model(self, model_path): try: - self._tokenizer = AutoTokenizer.from_pretrained(model_path + '/tokenizer') - self._config = AutoConfig.from_pretrained(model_path + '/tokenizer') + self._tokenizer = AutoTokenizer.from_pretrained( + model_path + '/tokenizer') + self._config = AutoConfig.from_pretrained( + model_path + '/tokenizer') + # Old models didn't use to have a tokenizer folder except OSError: self._tokenizer = AutoTokenizer.from_pretrained(model_path) self._config = AutoConfig.from_pretrained(model_path) - self._model = 
TFAutoModelForSequenceClassification.from_pretrained(model_path, - from_pt=False) + self._model = TFAutoModelForSequenceClassification.from_pretrained( + model_path, + from_pt=False + ) def _get_model_family(self): model_family = ''.join(self._model.name[2:].split('_')[:2]) @@ -259,7 +314,8 @@ def _load_remote_model(self, model_name, tokenizer_kwargs, model_kwargs): do_lower_case = True tokenizer_kwargs.update({'do_lower_case': do_lower_case}) - self._tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs) + self._tokenizer = AutoTokenizer.from_pretrained( + model_name, **tokenizer_kwargs) self._config = AutoConfig.from_pretrained(model_name) temporary_path = self._get_temporary_path() @@ -267,21 +323,30 @@ def _load_remote_model(self, model_name, tokenizer_kwargs, model_kwargs): # TensorFlow model try: - self._model = TFAutoModelForSequenceClassification.from_pretrained(model_name, - from_pt=False) + self._model = TFAutoModelForSequenceClassification.from_pretrained( + model_name, + from_pt=False + ) # PyTorch model except TypeError: try: - self._model = TFAutoModelForSequenceClassification.from_pretrained(model_name, - from_pt=True) - - # Loading a TF model from a PyTorch checkpoint is not supported when using a model identifier name + self._model = \ + TFAutoModelForSequenceClassification.from_pretrained( + model_name, + from_pt=True + ) + + # Loading a TF model from a PyTorch checkpoint is not supported + # when using a model identifier name except OSError: model = AutoModel.from_pretrained(model_name) model.save_pretrained(temporary_path) - self._model = TFAutoModelForSequenceClassification.from_pretrained(temporary_path, - from_pt=True) + self._model = \ + TFAutoModelForSequenceClassification.from_pretrained( + temporary_path, + from_pt=True + ) # Clean the model's last layer if the provided properties are different clean_last_layer = False @@ -295,28 +360,34 @@ def _load_remote_model(self, model_name, tokenizer_kwargs, model_kwargs): break if clean_last_layer: - model_family = self._get_model_family() try: - getattr(self._model, self._get_model_family()).save_pretrained(temporary_path) - self._model = self._model.__class__.from_pretrained(temporary_path, - from_pt=False, - **model_kwargs) + getattr(self._model, self._get_model_family() + ).save_pretrained(temporary_path) + self._model = self._model.__class__.from_pretrained( + temporary_path, + from_pt=False, + **model_kwargs + ) # The model is itself the main layer except AttributeError: # TensorFlow model try: - self._model = self._model.__class__.from_pretrained(model_name, - from_pt=False, - **model_kwargs) + self._model = self._model.__class__.from_pretrained( + model_name, + from_pt=False, + **model_kwargs + ) # PyTorch Model except (OSError, TypeError): model = AutoModel.from_pretrained(model_name) model.save_pretrained(temporary_path) - self._model = self._model.__class__.from_pretrained(temporary_path, - from_pt=True, - **model_kwargs) + self._model = self._model.__class__.from_pretrained( + temporary_path, + from_pt=True, + **model_kwargs + ) remove_dir(temporary_path) - assert self._tokenizer and self._model \ No newline at end of file + assert self._tokenizer and self._model diff --git a/ernie/helper.py b/ernie/helper.py index fe89c45..aa2483c 100644 --- a/ernie/helper.py +++ b/ernie/helper.py @@ -10,21 +10,30 @@ def get_features(tokenizer, sentences, labels): features = [] for i, sentence in enumerate(sentences): - inputs = tokenizer.encode_plus(sentence, add_special_tokens=True, 
max_length=tokenizer.max_len) - input_ids, token_type_ids = inputs['input_ids'], inputs['token_type_ids'] + inputs = tokenizer.encode_plus( + sentence, + add_special_tokens=True, + max_length=tokenizer.max_len + ) + input_ids, token_type_ids = \ + inputs['input_ids'], inputs['token_type_ids'] padding_length = tokenizer.max_len - len(input_ids) if tokenizer.padding_side == 'right': attention_mask = [1] * len(input_ids) + [0] * padding_length input_ids = input_ids + [tokenizer.pad_token_id] * padding_length - token_type_ids = token_type_ids + [tokenizer.pad_token_type_id] * padding_length + token_type_ids = token_type_ids + \ + [tokenizer.pad_token_type_id] * padding_length else: attention_mask = [0] * padding_length + [1] * len(input_ids) input_ids = [tokenizer.pad_token_id] * padding_length + input_ids - token_type_ids = [tokenizer.pad_token_type_id] * padding_length + token_type_ids + token_type_ids = \ + [tokenizer.pad_token_type_id] * padding_length + token_type_ids - assert tokenizer.max_len == len(attention_mask) == len(input_ids) == len( - token_type_ids), f'{tokenizer.max_len}, {len(attention_mask)}, {len(input_ids)}, {len(token_type_ids)}' + assert tokenizer.max_len \ + == len(attention_mask) \ + == len(input_ids) \ + == len(token_type_ids) feature = { 'input_ids': input_ids, diff --git a/ernie/models.py b/ernie/models.py index 0275c49..81596b8 100644 --- a/ernie/models.py +++ b/ernie/models.py @@ -29,13 +29,23 @@ class Models: class ModelsByFamily: - Bert = set([Models.BertBaseUncased, Models.BertBaseCased, Models.BertLargeUncased, Models.BertLargeCased]) + Bert = set([Models.BertBaseUncased, Models.BertBaseCased, + Models.BertLargeUncased, Models.BertLargeCased]) Roberta = set([Models.RobertaBaseCased, Models.RobertaLargeCased]) XLNet = set([Models.XLNetBaseCased, Models.XLNetLargeCased]) - DistilBert = set([Models.DistilBertBaseUncased, Models.DistilBertBaseMultilingualCased]) + DistilBert = set([Models.DistilBertBaseUncased, + Models.DistilBertBaseMultilingualCased]) Albert = set([ - Models.AlbertBaseCased, Models.AlbertLargeCased, Models.AlbertXLargeCased, Models.AlbertXXLargeCased, - Models.AlbertBaseCased2, Models.AlbertLargeCased2, Models.AlbertXLargeCased2, Models.AlbertXXLargeCased2 + Models.AlbertBaseCased, + Models.AlbertLargeCased, + Models.AlbertXLargeCased, + Models.AlbertXXLargeCased, + Models.AlbertBaseCased2, + Models.AlbertLargeCased2, + Models.AlbertXLargeCased2, + Models.AlbertXXLargeCased2 + ]) + Supported = set([ + getattr(Models, model_type) for model_type + in filter(lambda x: x[:2] != '__', Models.__dict__.keys()) ]) - Supported = set( - [getattr(Models, model_type) for model_type in filter(lambda x: x[:2] != '__', Models.__dict__.keys())]) \ No newline at end of file diff --git a/ernie/split_strategies.py b/ernie/split_strategies.py index 5f2a31c..4a0b47f 100644 --- a/ernie/split_strategies.py +++ b/ernie/split_strategies.py @@ -11,18 +11,27 @@ class RegexExpressions: split_by_comma = re.compile(r'[^,]+(?:\,\s*)?') url = re.compile( - r'https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)') + r'https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}' + r'\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)' + ) domain = re.compile(r'\w+\.\w+') class SplitStrategy: - def __init__(self, split_patterns, remove_patterns=None, group_splits=True, remove_too_short_groups=True): + def __init__( + self, + split_patterns, + remove_patterns=None, + group_splits=True, + remove_too_short_groups=True + ): if not 
isinstance(split_patterns, list): self.split_patterns = [split_patterns] else: self.split_patterns = split_patterns - if remove_patterns is not None and not isinstance(remove_patterns, list): + if remove_patterns is not None \ + and not isinstance(remove_patterns, list): self.remove_patterns = [remove_patterns] else: self.remove_patterns = remove_patterns @@ -57,21 +66,20 @@ def len_in_tokens(text_): for split in splits: if len_in_tokens(split) > max_tokens: if len(split_patterns) > 1: - sub_splits = self.split(split, tokenizer, split_patterns[1:]) + sub_splits = self.split( + split, tokenizer, split_patterns[1:]) selected_splits.extend(sub_splits) - else: selected_splits.append(split) else: if not self.group_splits: selected_splits.append(split) - else: - new_aggregated_splits = f'{aggregated_splits} {split}'.strip() + new_aggregated_splits = \ + f'{aggregated_splits} {split}'.strip() if len_in_tokens(new_aggregated_splits) <= max_tokens: aggregated_splits = new_aggregated_splits - else: selected_splits.append(aggregated_splits) aggregated_splits = split @@ -80,8 +88,8 @@ def len_in_tokens(text_): selected_splits.append(aggregated_splits) remove_too_short_groups = len(selected_splits) > 1 \ - and self.group_splits \ - and self.remove_too_short_groups + and self.group_splits \ + and self.remove_too_short_groups if not remove_too_short_groups: final_splits = selected_splits @@ -97,17 +105,21 @@ def len_in_tokens(text_): class SplitStrategies: SentencesWithoutUrls = SplitStrategy(split_patterns=[ - RegexExpressions.split_by_dot, RegexExpressions.split_by_semicolon, RegexExpressions.split_by_colon, + RegexExpressions.split_by_dot, + RegexExpressions.split_by_semicolon, + RegexExpressions.split_by_colon, RegexExpressions.split_by_comma ], - remove_patterns=[RegexExpressions.url, RegexExpressions.domain], - remove_too_short_groups=False, - group_splits=False) + remove_patterns=[RegexExpressions.url, RegexExpressions.domain], + remove_too_short_groups=False, + group_splits=False) GroupedSentencesWithoutUrls = SplitStrategy(split_patterns=[ - RegexExpressions.split_by_dot, RegexExpressions.split_by_semicolon, RegexExpressions.split_by_colon, + RegexExpressions.split_by_dot, + RegexExpressions.split_by_semicolon, + RegexExpressions.split_by_colon, RegexExpressions.split_by_comma ], - remove_patterns=[RegexExpressions.url, RegexExpressions.domain], - remove_too_short_groups=True, - group_splits=True) + remove_patterns=[RegexExpressions.url, RegexExpressions.domain], + remove_too_short_groups=True, + group_splits=True) diff --git a/examples/binary_classifier.py b/examples/binary_classifier.py index d1747a4..a3fe821 100644 --- a/examples/binary_classifier.py +++ b/examples/binary_classifier.py @@ -4,15 +4,22 @@ from ernie import SentenceClassifier, Models import pandas as pd -tuples = [("This is a positive example. I'm very happy today.", 1), - ("This is a negative sentence. Everything was wrong today at work.", 0)] +tuples = [ + ("This is a positive example. I'm very happy today.", 1), + ("This is a negative sentence. 
Everything was wrong today at work.", 0) +] df = pd.DataFrame(tuples) -classifier = SentenceClassifier(model_name=Models.BertBaseUncased, max_length=128, labels_no=2) +classifier = SentenceClassifier( + model_name=Models.BertBaseUncased, max_length=128, labels_no=2) classifier.load_dataset(df, validation_split=0.2) -classifier.fine_tune(epochs=4, learning_rate=2e-5, training_batch_size=32, validation_batch_size=64) +classifier.fine_tune(epochs=4, learning_rate=2e-5, + training_batch_size=32, validation_batch_size=64) sentence = "Oh, that's great!" probability = classifier.predict_one(sentence)[1] -print(f"\"{sentence}\": {probability} [{'positive' if probability >= 0.5 else 'negative'}]") \ No newline at end of file +print( + f"\"{sentence}\": {probability} " + f"[{'positive' if probability >= 0.5 else 'negative'}]" +) diff --git a/test/__init__.py b/test/__init__.py index 55e5d53..fbac284 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1,4 +1,4 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from ernie import * \ No newline at end of file +from ernie import * # noqa: F401, F403 diff --git a/test/dump_load.py b/test/dump_load.py index 46ee2a9..2fc8d60 100644 --- a/test/dump_load.py +++ b/test/dump_load.py @@ -7,8 +7,10 @@ class TestDumpAndLoad(unittest.TestCase): - df = pd.DataFrame([("This is a positive example. I'm very happy today.", 1), - ("This is a negative sentence. Everything was wrong today at work.", 0)]) + df = pd.DataFrame([ + ("This is a positive example. I'm very happy today.", 1), + ("This is a negative sentence. Everything was wrong today at work.", 0) + ]) sentence = "Oh, that's great!" def test_bert(self): @@ -41,4 +43,4 @@ def _test_dump_and_load_model(self, model_name): if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/test/load_csv.py b/test/load_csv.py index 4ab9ab6..2b12165 100644 --- a/test/load_csv.py +++ b/test/load_csv.py @@ -4,16 +4,30 @@ import unittest from ernie import SentenceClassifier, Models + class TestLoadCsv(unittest.TestCase): - classifier = SentenceClassifier(model_name=Models.BertBaseUncased, max_length=128, labels_no=2) - classifier.load_dataset(validation_split=0.2,csv_path="example.csv",read_csv_kwargs={"header":None}) - classifier.fine_tune(epochs=4, learning_rate=2e-5, training_batch_size=32, validation_batch_size=64) + classifier = SentenceClassifier( + model_name=Models.BertBaseUncased, + max_length=128, + labels_no=2 + ) + classifier.load_dataset( + validation_split=0.2, + csv_path="example.csv", + read_csv_kwargs={"header": None} + ) + classifier.fine_tune( + epochs=4, + learning_rate=2e-5, + training_batch_size=32, + validation_batch_size=64 + ) def test_predict(self): text = "Oh, that's great!" prediction = self.classifier.predict_one(text) - self.assertEqual(len(prediction),2) + self.assertEqual(len(prediction), 2) if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/test/load_model.py b/test/load_model.py index 92990f9..19851d5 100644 --- a/test/load_model.py +++ b/test/load_model.py @@ -3,23 +3,28 @@ import unittest import pandas as pd -from ernie import SentenceClassifier, Models +from ernie import SentenceClassifier, Models # noqa: F401 + class TestLoadModel(unittest.TestCase): - tuples = [("This is a negative sentence. Everything was wrong today at work.", 0), - ("This is a positive example. I'm very happy today.", 1), - ("This is a neutral sentence. That's normal.", 2)] + tuples = [ + ("This is a negative sentence. 
Everything was wrong today.", 0), + ("This is a positive example. I'm very happy today.", 1), + ("This is a neutral sentence. That's normal.", 2) + ] df = pd.DataFrame(tuples) - classifier = SentenceClassifier(model_name='xlm-roberta-large', max_length=128, labels_no=3) + classifier = SentenceClassifier( + model_name='xlm-roberta-large', max_length=128, labels_no=3) classifier.load_dataset(df, validation_split=0.2) - classifier.fine_tune(epochs=4, learning_rate=2e-5, training_batch_size=32, validation_batch_size=64) + classifier.fine_tune(epochs=4, learning_rate=2e-5, + training_batch_size=32, validation_batch_size=64) def test_predict(self): text = "Oh, that's great!" prediction = self.classifier.predict_one(text) - self.assertEqual(len(prediction),3) + self.assertEqual(len(prediction), 3) + if __name__ == '__main__': unittest.main() - diff --git a/test/predict.py b/test/predict.py index 1e14a38..b8b2034 100644 --- a/test/predict.py +++ b/test/predict.py @@ -6,18 +6,23 @@ class TestPredict(unittest.TestCase): - classifier = SentenceClassifier(model_name=Models.BertBaseUncased, max_length=128, labels_no=2) + classifier = SentenceClassifier( + model_name=Models.BertBaseUncased, + max_length=128, + labels_no=2 + ) def test_batch_predict(self): sentences_no = 50 - predictions = list(self.classifier.predict(["this is a test " * 100] * sentences_no)) + predictions = list(self.classifier.predict( + ["this is a test " * 100] * sentences_no) + ) self.assertEqual(len(predictions), sentences_no) def test_predict_one(self): - sentences_no = 50 prediction = self.classifier.predict_one("this is a test " * 100) self.assertEqual(len(prediction), 2) if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/test/split_aggregate.py b/test/split_aggregate.py index c2955e3..db7a01c 100644 --- a/test/split_aggregate.py +++ b/test/split_aggregate.py @@ -2,8 +2,13 @@ # -*- coding: utf-8 -*- import unittest -import pandas as pd -from ernie import SentenceClassifier, Models, AggregationStrategy, SplitStrategy, RegexExpressions +from ernie import ( + SentenceClassifier, + Models, + AggregationStrategy, + SplitStrategy, + RegexExpressions +) from statistics import mean import logging @@ -18,7 +23,8 @@ def round_tuple_of_floats(tuple_): class TestSplitAggregate(unittest.TestCase): logging.disable(logging.WARNING) - classifier = SentenceClassifier(model_name=Models.BertBaseUncased, max_length=128) + classifier = SentenceClassifier( + model_name=Models.BertBaseUncased, max_length=128) def test_aggregate_two_classes(self): softmax_tuples = ((0.2, 0.8), (0.8, 0.2), (0.5, 0.5)) @@ -27,11 +33,13 @@ def test_aggregate_two_classes(self): aggregated_tuple = strategy.aggregate(softmax_tuples) self.assertEqual(aggregated_tuple, (1 / 2, 1 / 2)) - strategy = AggregationStrategy(method=mean, max_items=2, sorting_class_index=1, top_items=True) + strategy = AggregationStrategy( + method=mean, max_items=2, sorting_class_index=1, top_items=True) aggregated_tuple = strategy.aggregate(softmax_tuples) self.assertEqual(aggregated_tuple, (0.35, 0.65)) - strategy = AggregationStrategy(method=mean, max_items=2, sorting_class_index=1, top_items=False) + strategy = AggregationStrategy( + method=mean, max_items=2, sorting_class_index=1, top_items=False) aggregated_tuple = strategy.aggregate(softmax_tuples) self.assertEqual(aggregated_tuple, (0.65, 0.35)) @@ -42,138 +50,160 @@ def test_aggregate_three_classes(self): aggregated_tuple = strategy.aggregate(softmax_tuples) 
self.assertEqual(aggregated_tuple, (1 / 3, 1 / 3, 1 / 3)) - strategy = AggregationStrategy(method=mean, max_items=2, sorting_class_index=0, top_items=True) - self.assertEqual(round_tuple_of_floats(strategy.aggregate(softmax_tuples)), (0.45, 0.4, 0.15)) + strategy = AggregationStrategy( + method=mean, max_items=2, sorting_class_index=0, top_items=True) + self.assertEqual(round_tuple_of_floats( + strategy.aggregate(softmax_tuples)), (0.45, 0.4, 0.15)) - strategy = AggregationStrategy(method=mean, max_items=2, sorting_class_index=0, top_items=False) - self.assertEqual(round_tuple_of_floats(strategy.aggregate(softmax_tuples)), (0.15, 0.45, 0.4)) + strategy = AggregationStrategy( + method=mean, max_items=2, sorting_class_index=0, top_items=False) + self.assertEqual(round_tuple_of_floats( + strategy.aggregate(softmax_tuples)), (0.15, 0.45, 0.4)) def test_split_groups(self): - splitter = SplitStrategy(split_patterns=[ - RegexExpressions.split_by_dot, RegexExpressions.split_by_semicolon, RegexExpressions.split_by_colon, - RegexExpressions.split_by_comma - ], - remove_patterns=[RegexExpressions.url, RegexExpressions.domain], - remove_too_short_groups=False, - group_splits=True) + splitter = SplitStrategy( + split_patterns=[ + RegexExpressions.split_by_dot, + RegexExpressions.split_by_semicolon, + RegexExpressions.split_by_colon, + RegexExpressions.split_by_comma + ], + remove_patterns=[RegexExpressions.url, RegexExpressions.domain], + remove_too_short_groups=False, + group_splits=True + ) # 256 tokens + 2 special tokens => no action (single sentence) - sentence = "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255" + sentence = "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255" # noqa: E501 expected_sentences = [ - "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 
33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255" + "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255" # noqa: E501 ] sentences = splitter.split(sentence, self.classifier.tokenizer) self.assertEqual(sentences, expected_sentences) # 128 tokens + 2 special tokens => 2 tokens exceeded in second group - sentence = "0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36. 37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55. 56. 57. 58. 59. 60. 61. 62. 63." + sentence = "0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36. 37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55. 56. 57. 58. 59. 60. 61. 62. 63." # noqa: E501 expected_sentences = [ - "0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36. 37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55. 56. 57. 58. 59. 60. 61. 62.", + "0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36. 37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55. 56. 57. 58. 59. 60. 61. 62.", # noqa: E501 "63." 
] sentences = splitter.split(sentence, self.classifier.tokenizer) self.assertEqual(sentences, expected_sentences) # 128 tokens + 2 special tokens => 2 tokens exceeded in second group - sentence = "0; 1; 2; 3; 4; 5; 6; 7; 8; 9; 10; 11; 12; 13; 14; 15; 16; 17; 18; 19; 20; 21; 22; 23; 24; 25; 26; 27; 28; 29; 30; 31; 32; 33; 34; 35; 36; 37; 38; 39; 40; 41; 42; 43; 44; 45; 46; 47; 48; 49; 50; 51; 52; 53; 54; 55; 56; 57; 58; 59; 60; 61; 62; 63;" + sentence = "0; 1; 2; 3; 4; 5; 6; 7; 8; 9; 10; 11; 12; 13; 14; 15; 16; 17; 18; 19; 20; 21; 22; 23; 24; 25; 26; 27; 28; 29; 30; 31; 32; 33; 34; 35; 36; 37; 38; 39; 40; 41; 42; 43; 44; 45; 46; 47; 48; 49; 50; 51; 52; 53; 54; 55; 56; 57; 58; 59; 60; 61; 62; 63;" # noqa: E501 expected_sentences = [ - "0; 1; 2; 3; 4; 5; 6; 7; 8; 9; 10; 11; 12; 13; 14; 15; 16; 17; 18; 19; 20; 21; 22; 23; 24; 25; 26; 27; 28; 29; 30; 31; 32; 33; 34; 35; 36; 37; 38; 39; 40; 41; 42; 43; 44; 45; 46; 47; 48; 49; 50; 51; 52; 53; 54; 55; 56; 57; 58; 59; 60; 61; 62;", + "0; 1; 2; 3; 4; 5; 6; 7; 8; 9; 10; 11; 12; 13; 14; 15; 16; 17; 18; 19; 20; 21; 22; 23; 24; 25; 26; 27; 28; 29; 30; 31; 32; 33; 34; 35; 36; 37; 38; 39; 40; 41; 42; 43; 44; 45; 46; 47; 48; 49; 50; 51; 52; 53; 54; 55; 56; 57; 58; 59; 60; 61; 62;", # noqa: E501 "63;" ] sentences = splitter.split(sentence, self.classifier.tokenizer) self.assertEqual(sentences, expected_sentences) # 128 tokens + 2 special tokens => 2 tokens exceeded in second group - sentence = "0: 1: 2: 3: 4: 5: 6: 7: 8: 9: 10: 11: 12: 13: 14: 15: 16: 17: 18: 19: 20: 21: 22: 23: 24: 25: 26: 27: 28: 29: 30: 31: 32: 33: 34: 35: 36: 37: 38: 39: 40: 41: 42: 43: 44: 45: 46: 47: 48: 49: 50: 51: 52: 53: 54: 55: 56: 57: 58: 59: 60: 61: 62: 63: " + sentence = "0: 1: 2: 3: 4: 5: 6: 7: 8: 9: 10: 11: 12: 13: 14: 15: 16: 17: 18: 19: 20: 21: 22: 23: 24: 25: 26: 27: 28: 29: 30: 31: 32: 33: 34: 35: 36: 37: 38: 39: 40: 41: 42: 43: 44: 45: 46: 47: 48: 49: 50: 51: 52: 53: 54: 55: 56: 57: 58: 59: 60: 61: 62: 63: " # noqa: E501 expected_sentences = [ - "0: 1: 2: 3: 4: 5: 6: 7: 8: 9: 10: 11: 12: 13: 14: 15: 16: 17: 18: 19: 20: 21: 22: 23: 24: 25: 26: 27: 28: 29: 30: 31: 32: 33: 34: 35: 36: 37: 38: 39: 40: 41: 42: 43: 44: 45: 46: 47: 48: 49: 50: 51: 52: 53: 54: 55: 56: 57: 58: 59: 60: 61: 62:", + "0: 1: 2: 3: 4: 5: 6: 7: 8: 9: 10: 11: 12: 13: 14: 15: 16: 17: 18: 19: 20: 21: 22: 23: 24: 25: 26: 27: 28: 29: 30: 31: 32: 33: 34: 35: 36: 37: 38: 39: 40: 41: 42: 43: 44: 45: 46: 47: 48: 49: 50: 51: 52: 53: 54: 55: 56: 57: 58: 59: 60: 61: 62:", # noqa: E501 "63:" ] sentences = splitter.split(sentence, self.classifier.tokenizer) self.assertEqual(sentences, expected_sentences) # 128 tokens + 2 special tokens => 2 tokens exceeded in second group - sentence = "0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, " + sentence = "0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, " # noqa: E501 expected_sentences = [ - "0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62,", + "0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 
15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62,", # noqa: E501 "63," ] sentences = splitter.split(sentence, self.classifier.tokenizer) self.assertEqual(sentences, expected_sentences) # 128 tokens + 2 special tokens => two groups splitted by the comma - sentence = "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63, 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127" + sentence = "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63, 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127" # noqa: E501 expected_sentences = [ - "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63,", - "64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127" + "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63,", # noqa: E501 + "64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127" # noqa: E501 ] sentences = splitter.split(sentence, self.classifier.tokenizer) self.assertEqual(sentences, expected_sentences) - # 128 tokens + 2 special tokens => two groups splitted by the period and not by the comma - sentence = "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15. 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63, 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127" + # 128 tokens + 2 special tokens => + # two groups splitted by the period and not by the comma + sentence = "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15. 
16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63, 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127" # noqa: E501 expected_sentences = [ "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15.", - "16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63, 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127", + "16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63, 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127", # noqa: E501 ] sentences = splitter.split(sentence, self.classifier.tokenizer) self.assertEqual(sentences, expected_sentences) def test_split_sentences(self): - splitter = SplitStrategy(split_patterns=[ - RegexExpressions.split_by_dot, RegexExpressions.split_by_semicolon, RegexExpressions.split_by_colon, - RegexExpressions.split_by_comma - ], - remove_patterns=[RegexExpressions.url, RegexExpressions.domain], - remove_too_short_groups=False, - group_splits=False) - - # 128 tokens + 2 special tokens => two sentences splitted by the period and not by the comma - sentence = "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15. 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63, 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127" + splitter = SplitStrategy( + split_patterns=[ + RegexExpressions.split_by_dot, + RegexExpressions.split_by_semicolon, + RegexExpressions.split_by_colon, + RegexExpressions.split_by_comma + ], + remove_patterns=[RegexExpressions.url, RegexExpressions.domain], + remove_too_short_groups=False, + group_splits=False + ) + + # 128 tokens + 2 special tokens => + # two sentences splitted by the period and not by the comma + sentence = "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15. 
16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63, 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127" # noqa: E501 expected_sentences = [ "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15.", - "16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63, 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127", + "16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63, 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127", # noqa: E501 ] sentences = splitter.split(sentence, self.classifier.tokenizer) self.assertEqual(sentences, expected_sentences) - # 128 tokens + 2 special tokens => two sentences splitted by the period and not by the comma - sentence = "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15, 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63. 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127" + # 128 tokens + 2 special tokens => + # two sentences splitted by the period and not by the comma + sentence = "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15, 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63. 
64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127" # noqa: E501 expected_sentences = [ - "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15, 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63.", - "64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127", + "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15, 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63.", # noqa: E501 + "64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127", # noqa: E501 ] sentences = splitter.split(sentence, self.classifier.tokenizer) self.assertEqual(sentences, expected_sentences) - # 256 tokens + 2 special tokens => three sentences: split first with the period, then with the comma - sentence = "0 1 2 3 4 5 6, 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127. 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255" + # 256 tokens + 2 special tokens => + # three sentences: split first with the period, then with the comma + sentence = "0 1 2 3 4 5 6, 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127. 
128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255" # noqa: E501 expected_sentences = [ "0 1 2 3 4 5 6,", - "7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127.", - "128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255" + "7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127.", # noqa: E501 + "128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255" # noqa: E501 ] sentences = splitter.split(sentence, self.classifier.tokenizer) self.assertEqual(sentences, expected_sentences) def test_split_groups_remove_too_short(self): - splitter = SplitStrategy(split_patterns=[ - RegexExpressions.split_by_dot, RegexExpressions.split_by_semicolon, RegexExpressions.split_by_colon, - RegexExpressions.split_by_comma - ], - remove_patterns=[RegexExpressions.url, RegexExpressions.domain], - remove_too_short_groups=True, - group_splits=True) - - # 256 tokens + 2 special tokens => three sentences: split first with the period, then with the comma; remove first group - sentence = "0 1 2 3 4 5 6, 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127. 
128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255" + splitter = SplitStrategy( + split_patterns=[ + RegexExpressions.split_by_dot, + RegexExpressions.split_by_semicolon, + RegexExpressions.split_by_colon, + RegexExpressions.split_by_comma + ], + remove_patterns=[RegexExpressions.url, RegexExpressions.domain], + remove_too_short_groups=True, + group_splits=True + ) + + # 256 tokens + 2 special tokens => + # three sentences: split first with the period, then with the comma; + # remove first group + sentence = "0 1 2 3 4 5 6, 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127. 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255" # noqa: E501 expected_sentences = [ - "7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127.", - "128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255" + "7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127.", # noqa: E501 + "128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 
204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255" # noqa: E501 ] sentences = splitter.split(sentence, self.classifier.tokenizer) self.assertEqual(sentences, expected_sentences) if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() From 9308b2ce1e6725c235a10e64c763d671f9716548 Mon Sep 17 00:00:00 2001 From: brunneis Date: Wed, 22 Dec 2021 22:13:37 +0100 Subject: [PATCH 6/8] update readme --- README.md | 46 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 2a99a94..47bf6b1 100644 --- a/README.md +++ b/README.md @@ -32,13 +32,24 @@ pip install ernie from ernie import SentenceClassifier, Models import pandas as pd -tuples = [("This is a positive example. I'm very happy today.", 1), - ("This is a negative sentence. Everything was wrong today at work.", 0)] +tuples = [ + ("This is a positive example. I'm very happy today.", 1), + ("This is a negative sentence. Everything was wrong today at work.", 0) +] df = pd.DataFrame(tuples) -classifier = SentenceClassifier(model_name=Models.BertBaseUncased, max_length=64, labels_no=2) +classifier = SentenceClassifier( + model_name=Models.BertBaseUncased, + max_length=64, + labels_no=2 +) classifier.load_dataset(df, validation_split=0.2) -classifier.fine_tune(epochs=4, learning_rate=2e-5, training_batch_size=32, validation_batch_size=64) +classifier.fine_tune( + epochs=4, + learning_rate=2e-5, + training_batch_size=32, + validation_batch_size=64 +) ``` # Prediction @@ -76,9 +87,11 @@ If the length in tokens of the texts is greater than the `max_length` with which from ernie import SplitStrategies, AggregationStrategies texts = ["Oh, that's great!", "That's really bad"] -probabilities = classifier.predict(texts, - split_strategy=SplitStrategies.GroupedSentencesWithoutUrls, - aggregation_strategy=AggregationStrategies.Mean) +probabilities = classifier.predict( + texts, + split_strategy=SplitStrategies.GroupedSentencesWithoutUrls, + aggregation_strategy=AggregationStrategies.Mean +) ``` @@ -86,8 +99,18 @@ You can define your custom strategies through `AggregationStrategy` and `SplitSt ```python from ernie import SplitStrategy, AggregationStrategy -my_split_strategy = SplitStrategy(split_patterns: list, remove_patterns: list, remove_too_short_groups: bool, group_splits: bool) -my_aggregation_strategy = AggregationStrategy(method: function, max_items: int, top_items: bool, sorting_class_index: int) +my_split_strategy = SplitStrategy( + split_patterns: list, + remove_patterns: list, + remove_too_short_groups: bool, + group_splits: bool +) +my_aggregation_strategy = AggregationStrategy( + method: function, + max_items: int, + top_items: bool, + sorting_class_index: int +) ``` # Save and restore a fine-tuned model @@ -105,7 +128,10 @@ classifier = SentenceClassifier(model_path='./model') Since the execution may break during training (especially if you are using Google Colab), you can opt to secure every new trained epoch, so the training can be resumed without losing all the progress. 
```python -classifier = SentenceClassifier(model_name=Models.BertBaseUncased, max_length=64) +classifier = SentenceClassifier( + model_name=Models.BertBaseUncased, + max_length=64 +) classifier.load_dataset(df, validation_split=0.2) for epoch in range(1, 5): From c90149a1918998da38191758865fc8397234c346 Mon Sep 17 00:00:00 2001 From: brunneis Date: Wed, 22 Dec 2021 22:15:58 +0100 Subject: [PATCH 7/8] update readme --- README.md | 62 +++++++++++++++++++++++++++---------------------------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 47bf6b1..5265d16 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,14 @@


 [README header hunk — original HTML markup not preserved in this extract: the Ernie logo image line is modified; in the badge row, the Downloads and PyPi badges are kept as context, the old License badge line is removed, and GitHub releases and License badge lines are added]

@@ -33,22 +33,22 @@ from ernie import SentenceClassifier, Models import pandas as pd tuples = [ - ("This is a positive example. I'm very happy today.", 1), - ("This is a negative sentence. Everything was wrong today at work.", 0) + ("This is a positive example. I'm very happy today.", 1), + ("This is a negative sentence. Everything was wrong today at work.", 0) ] df = pd.DataFrame(tuples) classifier = SentenceClassifier( - model_name=Models.BertBaseUncased, - max_length=64, - labels_no=2 + model_name=Models.BertBaseUncased, + max_length=64, + labels_no=2 ) classifier.load_dataset(df, validation_split=0.2) classifier.fine_tune( - epochs=4, - learning_rate=2e-5, - training_batch_size=32, - validation_batch_size=64 + epochs=4, + learning_rate=2e-5, + training_batch_size=32, + validation_batch_size=64 ) ``` @@ -88,9 +88,9 @@ from ernie import SplitStrategies, AggregationStrategies texts = ["Oh, that's great!", "That's really bad"] probabilities = classifier.predict( - texts, - split_strategy=SplitStrategies.GroupedSentencesWithoutUrls, - aggregation_strategy=AggregationStrategies.Mean + texts, + split_strategy=SplitStrategies.GroupedSentencesWithoutUrls, + aggregation_strategy=AggregationStrategies.Mean ) ``` @@ -100,16 +100,16 @@ You can define your custom strategies through `AggregationStrategy` and `SplitSt from ernie import SplitStrategy, AggregationStrategy my_split_strategy = SplitStrategy( - split_patterns: list, - remove_patterns: list, - remove_too_short_groups: bool, - group_splits: bool + split_patterns: list, + remove_patterns: list, + remove_too_short_groups: bool, + group_splits: bool ) my_aggregation_strategy = AggregationStrategy( - method: function, - max_items: int, - top_items: bool, - sorting_class_index: int + method: function, + max_items: int, + top_items: bool, + sorting_class_index: int ) ``` @@ -129,17 +129,17 @@ Since the execution may break during training (especially if you are using Googl ```python classifier = SentenceClassifier( - model_name=Models.BertBaseUncased, - max_length=64 + model_name=Models.BertBaseUncased, + max_length=64 ) classifier.load_dataset(df, validation_split=0.2) for epoch in range(1, 5): - if epoch == 3: - raise Exception("Forced crash") + if epoch == 3: + raise Exception("Forced crash") - classifier.fine_tune(epochs=1) - classifier.dump(f'./my-model/{epoch}') + classifier.fine_tune(epochs=1) + classifier.dump(f'./my-model/{epoch}') ``` ```python @@ -149,8 +149,8 @@ classifier = SentenceClassifier(model_path=f'./my-model/{last_training_epoch}') classifier.load_dataset(df, validation_split=0.2) for epoch in range(last_training_epoch + 1, 5): - classifier.fine_tune(epochs=1) - classifier.dump(f'./my-model/{epoch}') + classifier.fine_tune(epochs=1) + classifier.dump(f'./my-model/{epoch}') ``` # Autosave From 219f2e11ee9d47219b20572b180ed2f5a2f11f83 Mon Sep 17 00:00:00 2001 From: brunneis Date: Wed, 22 Dec 2021 22:17:54 +0100 Subject: [PATCH 8/8] update readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 5265d16..583c586 100644 --- a/README.md +++ b/README.md @@ -36,8 +36,8 @@ tuples = [ ("This is a positive example. I'm very happy today.", 1), ("This is a negative sentence. Everything was wrong today at work.", 0) ] - df = pd.DataFrame(tuples) + classifier = SentenceClassifier( model_name=Models.BertBaseUncased, max_length=64,