Add tests for chunking module #4

Merged (9 commits) on Sep 20, 2024
27 changes: 27 additions & 0 deletions .github/workflows/ci.yaml
@@ -0,0 +1,27 @@
name: Run Tests

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
test:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[dev]

- name: Run tests
run: pytest tests
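For reference, the job can be reproduced locally with the same three steps: python -m pip install --upgrade pip, pip install '.[dev]' (quoted so the shell does not glob the brackets), and pytest tests. This assumes the project's dev extra provides pytest, as the workflow implies.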
75 changes: 70 additions & 5 deletions chunked_pooling/chunking.py
@@ -19,13 +19,72 @@ def __init__(
chunking_strategy: str,
):
if chunking_strategy not in CHUNKING_STRATEGIES:
raise ValueError("Unsupported chunking strategy")
raise ValueError("Unsupported chunking strategy: ", chunking_strategy)
self.chunking_strategy = chunking_strategy
self.embed_model = None
self.embedding_model_name = None

def _setup_semantic_chunking(self, embedding_model_name):
if embedding_model_name:
self.embedding_model_name = embedding_model_name

self.embed_model = HuggingFaceEmbedding(
model_name=self.embedding_model_name,
trust_remote_code=True,
)
self.splitter = SemanticSplitterNodeParser(
embed_model=self.embed_model,
show_progress=False,
)

def chunk_semantically(
-        self, text: str, min_tokens: Optional[int] = None
-    ) -> List[Tuple[int, int, int]]:
-        raise NotImplementedError('Semantic Chunking is not supported at the moment')
+        self,
+        text: str,
+        tokenizer: 'AutoTokenizer',
+        embedding_model_name: Optional[str] = None,
+    ) -> List[Tuple[int, int]]:
if self.embed_model is None:
self._setup_semantic_chunking(embedding_model_name)

# Get semantic nodes
nodes = [
(node.start_char_idx, node.end_char_idx)
for node in self.splitter.get_nodes_from_documents(
[Document(text=text)], show_progress=False
)
]

# Tokenize the entire text
tokens = tokenizer.encode_plus(
text,
return_offsets_mapping=True,
add_special_tokens=False,
padding=True,
truncation=True,
)
token_offsets = tokens.offset_mapping

chunk_spans = []

for char_start, char_end in nodes:
# Convert char indices to token indices
start_chunk_index = bisect.bisect_left(
[offset[0] for offset in token_offsets], char_start
)
end_chunk_index = (
bisect.bisect_right([offset[1] for offset in token_offsets], char_end)
- 1
)

# Add the chunk span if it's within the tokenized text
if start_chunk_index < len(token_offsets) and end_chunk_index < len(
token_offsets
):
chunk_spans.append((start_chunk_index, end_chunk_index))
else:
break

return chunk_spans
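For intuition, the char-to-token span conversion above can be exercised in isolation. A minimal standalone sketch (the text, spans, and model are illustrative, not part of the PR; it assumes a fast tokenizer, which is what return_offsets_mapping requires):

import bisect

from transformers import AutoTokenizer

text = "Berlin is the capital of Germany. It is also one of its states."
sep = text.index(". ")  # position of the first sentence's final period
char_spans = [(0, sep + 1), (sep + 2, len(text))]  # what a splitter might emit

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
offsets = tokenizer.encode_plus(
    text, return_offsets_mapping=True, add_special_tokens=False
).offset_mapping

token_spans = []
for char_start, char_end in char_spans:
    # first token that starts at or after the chunk's first character
    start = bisect.bisect_left([o[0] for o in offsets], char_start)
    # last token that ends at or before the chunk's end (inclusive index, as in the PR)
    end = bisect.bisect_right([o[1] for o in offsets], char_end) - 1
    if start < len(offsets) and end < len(offsets):
        token_spans.append((start, end))

print(token_spans)  # e.g. [(0, 6), (7, 14)] for this text and tokenizer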

def chunk_by_tokens(
self,
@@ -81,9 +140,15 @@ def chunk(
chunking_strategy: str = None,
chunk_size: Optional[int] = None,
n_sentences: Optional[int] = None,
+        embedding_model_name: Optional[str] = None,
):
chunking_strategy = chunking_strategy or self.chunking_strategy
if chunking_strategy == "semantic":
-            return self.chunk_semantically(text)
+            return self.chunk_semantically(
+                text,
+                embedding_model_name=embedding_model_name,
+                tokenizer=tokenizer,
+            )
elif chunking_strategy == "fixed":
if chunk_size < 10:
raise ValueError("Chunk size must be greater than 10.")
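Taken together, the new signature is used as in the tests added below. A hedged usage sketch (model names mirror the tests; the exact spans depend on the tokenizer and splitter):

from transformers import AutoTokenizer

from chunked_pooling.chunking import Chunker

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
chunker = Chunker(chunking_strategy="fixed")

# Fixed-size chunking: returns (start, end) token spans of at most chunk_size tokens.
fixed_spans = chunker.chunk("some long text ...", tokenizer=tokenizer, chunk_size=256)

# Semantic chunking: override the strategy and pass an embedding model for the splitter.
semantic_spans = chunker.chunk(
    "some long text ...",
    tokenizer=tokenizer,
    chunking_strategy="semantic",
    embedding_model_name="jinaai/jina-embeddings-v2-small-en",
)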
2 changes: 1 addition & 1 deletion chunked_pooling/mteb_chunked_eval.py
@@ -5,9 +5,9 @@
import torch
from mteb.abstasks import AbsTask
from mteb.evaluation.evaluators import RetrievalEvaluator
-from mteb.load_results.mteb_results import ScoresDict
from mteb.tasks import Retrieval
from tqdm import tqdm
+from mteb.load_results.mteb_results import ScoresDict

from chunked_pooling import chunked_pooling
from chunked_pooling.chunking import Chunker
4 changes: 2 additions & 2 deletions chunked_pooling/wrappers.py
@@ -1,9 +1,9 @@
+from typing import List, Optional, Union

import torch
import torch.nn as nn
from transformers import AutoModel

-from typing import List, Union, Optional


class JinaEmbeddingsV3Wrapper(nn.Module):
def __init__(
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -10,7 +10,7 @@ dependencies = [
"llama-index-embeddings-huggingface==0.3.1",
"llama-index==0.11.10",
"click==8.1.7",
"einops==0.6.1"
"einops==0.6.1",
]
version = "0.0.0"

7 changes: 1 addition & 6 deletions run_chunked_eval.py
@@ -1,19 +1,14 @@
import click
import torch.cuda

-from transformers import AutoModel, AutoTokenizer
from mteb import MTEB
+from transformers import AutoModel, AutoTokenizer

from chunked_pooling.chunked_eval_tasks import *

from chunked_pooling.wrappers import load_model

-from chunked_pooling.wrappers import load_model

DEFAULT_CHUNKING_STRATEGY = 'fixed'
DEFAULT_CHUNK_SIZE = 256
DEFAULT_N_SENTENCES = 5

BATCH_SIZE = 1


1 change: 0 additions & 1 deletion tests/conftest.py
@@ -1,5 +1,4 @@
import pytest

from mteb.abstasks.TaskMetadata import TaskMetadata

from chunked_pooling.mteb_chunked_eval import AbsTaskChunkedRetrieval
99 changes: 88 additions & 11 deletions tests/test_chunking_methods.py
@@ -1,8 +1,7 @@
import pytest
+from transformers import AutoTokenizer

-from transformers import AutoTokenizer, AutoModel

-from chunked_pooling.chunking import Chunker
+from chunked_pooling.chunking import CHUNKING_STRATEGIES, Chunker
from chunked_pooling.mteb_chunked_eval import AbsTaskChunkedRetrieval

EXAMPLE_TEXT_1 = "Berlin is the capital and largest city of Germany, both by area and by population. Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits. The city is also one of the states of Germany, and is the third smallest state in the country in terms of area."
@@ -13,7 +12,7 @@
def test_chunk_by_sentences(n_sentences):
strategy = 'sentences'
model_name = 'jinaai/jina-embeddings-v2-small-en'
-    chunker = Chunker(strategy)
+    chunker = Chunker(chunking_strategy=strategy)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
boundary_cues = chunker.chunk(
text=EXAMPLE_TEXT_1,
@@ -32,21 +31,21 @@ def test_chunk_by_sentences(n_sentences):
max_length=8192,
)

-    # check that the cues start with 0 and and with the last token
+    # check that the cues start with 0 and end with the last token
    assert extended_boundary_cues[0][0] == 0
-    assert len(model_inputs.tokens(0)) == extended_boundary_cues[-1][1]
+    assert len(model_inputs.tokens()) == extended_boundary_cues[-1][1]

-    # check that all chunks but the last one end with a punctation
+    # check that all chunks but the last one end with a punctuation
    assert all(
-        model_inputs.tokens(0)[x:y][-1] in PUNCTATIONS
+        model_inputs.tokens()[x:y][-1] in PUNCTATIONS
for (x, y) in extended_boundary_cues[:-1]
)

# check that the last chunk ends with a "[SEP]" token
last_cue = extended_boundary_cues[-1]
-    assert model_inputs.tokens(0)[last_cue[0] : last_cue[1]][-1] == "[SEP]"
+    assert model_inputs.tokens()[last_cue[0] : last_cue[1]][-1] == "[SEP]"

-    # check that the boundary cues are continues (no token is missing)
+    # check that the boundary cues are continuous (no token is missing)
assert all(
[
extended_boundary_cues[i][1] == extended_boundary_cues[i + 1][0]
@@ -66,7 +65,7 @@ def test_token_equivalence(boundary_cues):
)
for start_token_idx, end_token_idx in boundary_cues:
decoded_text_chunk = tokenizer.decode(
-            tokens.encodings[0].ids[start_token_idx:end_token_idx]
+            tokens.input_ids[start_token_idx:end_token_idx]
)

original_text_chunk = EXAMPLE_TEXT_1[
@@ -77,3 +76,81 @@
chunk_tokens_original = tokenizer.encode_plus(original_text_chunk)
chunk_tokens_decoded = tokenizer.encode_plus(decoded_text_chunk)
assert chunk_tokens_original == chunk_tokens_decoded


def test_chunker_initialization():
for strategy in CHUNKING_STRATEGIES:
chunker = Chunker(chunking_strategy=strategy)
assert chunker.chunking_strategy == strategy


def test_invalid_chunking_strategy():
with pytest.raises(ValueError):
Chunker(chunking_strategy="invalid")


def test_chunk_by_tokens():
chunker = Chunker(chunking_strategy="fixed")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
chunks = chunker.chunk(EXAMPLE_TEXT_1, tokenizer=tokenizer, chunk_size=10)
assert len(chunks) > 1
for start, end in chunks:
assert end - start <= 10


def test_chunk_semantically():
chunker = Chunker(chunking_strategy="semantic")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
chunks = chunker.chunk(
EXAMPLE_TEXT_1,
tokenizer=tokenizer,
chunking_strategy='semantic',
embedding_model_name='jinaai/jina-embeddings-v2-small-en',
)
assert len(chunks) > 0


def test_empty_input():
chunker = Chunker(chunking_strategy="fixed")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
chunks = chunker.chunk("", tokenizer=tokenizer, chunk_size=10)
assert len(chunks) == 0


def test_input_shorter_than_chunk_size():
short_text = "Short text."
chunker = Chunker(chunking_strategy="fixed")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
chunks = chunker.chunk(short_text, tokenizer=tokenizer, chunk_size=20)
assert len(chunks) == 1


@pytest.mark.parametrize("chunk_size", [10, 20, 50])
def test_various_chunk_sizes(chunk_size):
chunker = Chunker(chunking_strategy="fixed")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
chunks = chunker.chunk(EXAMPLE_TEXT_1, tokenizer=tokenizer, chunk_size=chunk_size)
assert len(chunks) > 0
for start, end in chunks:
assert end - start <= chunk_size


def test_chunk_method_with_different_strategies():
chunker = Chunker(chunking_strategy="fixed")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
fixed_chunks = chunker.chunk(EXAMPLE_TEXT_1, tokenizer=tokenizer, chunk_size=10)
semantic_chunks = chunker.chunk(
EXAMPLE_TEXT_1,
tokenizer=tokenizer,
chunking_strategy="semantic",
embedding_model_name='jinaai/jina-embeddings-v2-small-en',
)
assert fixed_chunks != semantic_chunks


def test_chunk_by_sentences_different_n():
chunker = Chunker(chunking_strategy="sentences")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
chunks_1 = chunker.chunk(EXAMPLE_TEXT_1, tokenizer=tokenizer, n_sentences=1)
chunks_2 = chunker.chunk(EXAMPLE_TEXT_1, tokenizer=tokenizer, n_sentences=2)
assert len(chunks_1) > len(chunks_2)
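During development the new tests can be run selectively with pytest's keyword filter, e.g. pytest tests/test_chunking_methods.py -k semantic to exercise only the semantic-chunking cases.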
2 changes: 1 addition & 1 deletion tests/test_v3.py
@@ -1,6 +1,6 @@
from transformers import AutoTokenizer

-from run_chunked_eval import load_model, DEFAULT_CHUNK_SIZE
+from run_chunked_eval import DEFAULT_CHUNK_SIZE, load_model

MODEL_NAME = 'jinaai/jina-embeddings-v3'
