feat: add sem chunking, merge test files, isort and black
violenil committed Sep 20, 2024
1 parent b1d7b54 commit 15165e4
Showing 5 changed files with 141 additions and 216 deletions.
86 changes: 39 additions & 47 deletions chunked_pooling/chunking.py
@@ -19,25 +19,44 @@ def __init__(
         chunking_strategy: str,
     ):
         if chunking_strategy not in CHUNKING_STRATEGIES:
-            raise ValueError("Unsupported chunking strategy")
+            raise ValueError("Unsupported chunking strategy: ", chunking_strategy)
         self.chunking_strategy = chunking_strategy
+        self.embed_model = None
+        self.embedding_model_name = None
+
+    def _setup_semantic_chunking(self, embedding_model_name):
+        if embedding_model_name:
+            self.embedding_model_name = embedding_model_name
+
+        self.embed_model = HuggingFaceEmbedding(
+            model_name=self.embedding_model_name,
+            max_length=512,
+            trust_remote_code=True,
+        )
+        self.splitter = SemanticSplitterNodeParser(
+            embed_model=self.embed_model,
+            show_progress=False,
+        )
 
     def chunk_semantically(
-        self, text: str, min_tokens: Optional[int] = None
-    ) -> List[Tuple[int, int, int]]:
+        self,
+        text: str,
+        tokenizer: 'AutoTokenizer',
+        embedding_model_name: Optional[str] = None,
+    ) -> List[Tuple[int, int]]:
         if self.embed_model is None:
-            setup_semantic_chunking()
-
-        min_tokens = min_tokens or self.min_tokens
+            self._setup_semantic_chunking(embedding_model_name)
 
+        # Get semantic nodes
+        nodes = [
+            (node.start_char_idx, node.end_char_idx)
+            for node in self.splitter.get_nodes_from_documents(
+                [Document(text=text)], show_progress=False
+            )
+        ]
+
         # Tokenize the entire text
-        tokens = self.tokenizer.encode_plus(
+        tokens = tokenizer.encode_plus(
             text,
             return_offsets_mapping=True,
             add_special_tokens=False,
@@ -49,16 +68,8 @@ def chunk_semantically(
 
         chunk_spans = []
 
-        if len(token_offsets) < min_tokens:
-            # If the entire text has fewer than 10 tokens, return it as a single chunk
-            chunk_spans.append((0, len(token_offsets) - 1))
-            return chunk_spans
-
-        i = 0
-        while i < len(nodes):
-            char_start, char_end = nodes[i]
-
-            # convert char_start and char_end to token indices
+        for char_start, char_end in nodes:
+            # Convert char indices to token indices
             start_chunk_index = bisect.bisect_left(
                 [offset[0] for offset in token_offsets], char_start
             )
@@ -67,39 +78,14 @@
                 - 1
             )
 
-            # Ensure each chunk has at least min_tokens tokens
-            while (
-                end_chunk_index - start_chunk_index + 1 < min_tokens
-                and i < len(nodes) - 1
-            ):
-                # Merge with the next node
-                i += 1
-                char_end = nodes[i][1]
-                end_chunk_index = (
-                    bisect.bisect_right(
-                        [offset[1] for offset in token_offsets], char_end
-                    )
-                    - 1
-                )
-
-            # If the chunk is still less than min_tokens and it's the last node, handle it explicitly
-            if (
-                end_chunk_index - start_chunk_index + 1 < min_tokens
-                and i == len(nodes) - 1
-            ):
-                end_chunk_index = min(
-                    start_chunk_index + min_tokens - 1, len(token_offsets) - 1
-                )
-
-            # If the chunk is outside of the tokenized text, break out of loop
-            if start_chunk_index >= len(token_offsets) or end_chunk_index >= len(
+            # Add the chunk span if it's within the tokenized text
+            if start_chunk_index < len(token_offsets) and end_chunk_index < len(
                 token_offsets
             ):
+                chunk_spans.append((start_chunk_index, end_chunk_index))
+            else:
                 break
-
-            chunk_spans.append((start_chunk_index, end_chunk_index))
-            i += 1
 
         return chunk_spans
 
     def chunk_by_tokens(

@@ -156,9 +142,15 @@ def chunk(
         chunking_strategy: str = None,
         chunk_size: Optional[int] = None,
         n_sentences: Optional[int] = None,
+        embedding_model_name: Optional[str] = None,
     ):
         chunking_strategy = chunking_strategy or self.chunking_strategy
         if chunking_strategy == "semantic":
-            return self.chunk_semantically(text)
+            return self.chunk_semantically(
+                text,
+                embedding_model_name=embedding_model_name,
+                tokenizer=tokenizer,
+            )
         elif chunking_strategy == "fixed":
             if chunk_size < 10:
                 raise ValueError("Chunk size must be greater than 10.")
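For readers following the change, the new chunk_semantically works in two stages: llama-index's SemanticSplitterNodeParser yields nodes with character offsets, and the tokenizer's offset mapping plus bisect then converts each character span into a token span. The following condensed, self-contained sketch shows that pipeline; the model names are placeholders and the surrounding Chunker class is elided, so treat it as an illustration of the visible diff rather than the module itself.

# Standalone sketch of the semantic chunking pipeline added above.
# Assumptions: llama-index and transformers are installed; both model
# names below are placeholders, not the repository's defaults.
import bisect
from typing import List, Tuple

from llama_index.core import Document
from llama_index.core.node_parser import SemanticSplitterNodeParser
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from transformers import AutoTokenizer


def semantic_token_spans(text: str, tokenizer) -> List[Tuple[int, int]]:
    embed_model = HuggingFaceEmbedding(
        model_name='sentence-transformers/all-MiniLM-L6-v2',  # placeholder
        max_length=512,
    )
    splitter = SemanticSplitterNodeParser(embed_model=embed_model)

    # Stage 1: semantically coherent nodes, as character spans
    nodes = [
        (node.start_char_idx, node.end_char_idx)
        for node in splitter.get_nodes_from_documents(
            [Document(text=text)], show_progress=False
        )
    ]

    # Stage 2: map character spans to token spans via the offset mapping
    tokens = tokenizer.encode_plus(
        text, return_offsets_mapping=True, add_special_tokens=False
    )
    token_offsets = tokens['offset_mapping']
    starts = [start for start, _ in token_offsets]
    ends = [end for _, end in token_offsets]

    spans = []
    for char_start, char_end in nodes:
        # First token starting at or after the node's first character ...
        start_idx = bisect.bisect_left(starts, char_start)
        # ... and last token ending at or before its last character.
        end_idx = bisect.bisect_right(ends, char_end) - 1
        if start_idx < len(token_offsets) and end_idx < len(token_offsets):
            spans.append((start_idx, end_idx))
    return spans


tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # placeholder
print(semantic_token_spans('First topic. A second topic starts here.', tokenizer))

Because the spans index into the same offset mapping, chunk boundaries never split a token, and the final bounds check mirrors the loop in the diff by dropping any node that falls outside the tokenized text.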
12 changes: 4 additions & 8 deletions chunked_pooling/mteb_chunked_eval.py
@@ -5,9 +5,9 @@
 import torch
 from mteb.abstasks import AbsTask
 from mteb.evaluation.evaluators import RetrievalEvaluator
+from mteb.load_results.mteb_results import ScoresDict
 from mteb.tasks import Retrieval
 from tqdm import tqdm
-from mteb.load_results.mteb_results import ScoresDict
 
 from chunked_pooling import chunked_pooling
 from chunked_pooling.chunking import Chunker
@@ -57,11 +57,7 @@ def calculate_metadata_metrics(self):
         self.retrieval_task.calculate_metadata_metrics()
 
     def evaluate(
-        self,
-        model,
-        split: str = "test",
-        encode_kwargs: dict[str, Any] = {},
-        **kwargs
+        self, model, split: str = "test", encode_kwargs: dict[str, Any] = {}, **kwargs
     ) -> dict[str, ScoresDict]:
         scores: dict[str, ScoresDict] = {}
         hf_subsets = list(self.hf_subsets) if self.is_multilingual else ["default"]
@@ -81,11 +77,11 @@
                 self.queries[hf_subset][split],
                 self.relevant_docs[hf_subset][split],
             )
 
             scores[hf_subset] = self._evaluate_monolingual(
                 model, corpus, queries, relevant_docs, hf_subset, **kwargs
             )
 
         return scores
 
     def _evaluate_monolingual(
22 changes: 10 additions & 12 deletions run_chunked_eval.py
@@ -1,27 +1,25 @@
 import click
 import torch.cuda
 
-from transformers import AutoModel, AutoTokenizer
 from mteb import MTEB
+from transformers import AutoModel, AutoTokenizer
 
-from chunked_pooling.chunked_eval_tasks import (
-    SciFactChunked,
-    TRECCOVIDChunked,
-    FiQA2018Chunked,
-    NFCorpusChunked,
-    QuoraChunked,
-    LEMBWikimQARetrievalChunked,
-)
+from chunked_pooling.chunked_eval_tasks import (FiQA2018Chunked,
+                                                LEMBWikimQARetrievalChunked,
+                                                NFCorpusChunked, QuoraChunked,
+                                                SciFactChunked,
+                                                TRECCOVIDChunked)
 
 DEFAULT_CHUNKING_STRATEGY = 'fixed'
 DEFAULT_CHUNK_SIZE = 256
 DEFAULT_N_SENTENCES = 5
 
 
 def remove_prompt_name(original_encode):
     def wrapper(self, *args, **kwargs):
         # Remove 'prompt_name' from kwargs if present
         kwargs.pop('prompt_name', None)
         return original_encode(self, *args, **kwargs)
 
     return wrapper

@@ -84,7 +82,7 @@ def main(model_name, strategy, task_name):
         output_folder='results-chunked-pooling',
         eval_splits=['test'],
         overwrite_results=True,
-        encode_kwargs = {'batch_size': 1},
+        encode_kwargs={'batch_size': 1},
     )
 
     tasks = [
@@ -108,7 +106,7 @@ def main(model_name, strategy, task_name):
         output_folder='results-normal-pooling',
         eval_splits=['test'],
         overwrite_results=True,
-        encode_kwargs = {'batch_size': 1},
+        encode_kwargs={'batch_size': 1},
     )


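As a usage note, the fragments above outline how the evaluation is assembled: a chunked task is handed to MTEB and run with a batch size of 1 into the chunked-pooling output folder. A minimal sketch follows, assuming a hypothetical model name and a no-argument task constructor; neither assumption is confirmed by this diff.

# Hypothetical wiring sketch assembled from the run_chunked_eval.py
# fragments above; the model name and the bare task constructor are
# assumptions, not values taken from this commit.
from mteb import MTEB
from transformers import AutoModel

from chunked_pooling.chunked_eval_tasks import SciFactChunked

# Placeholder model; any encoder exposing an encode() method that
# MTEB understands would do here.
model = AutoModel.from_pretrained(
    'jinaai/jina-embeddings-v2-small-en', trust_remote_code=True
)

evaluation = MTEB(tasks=[SciFactChunked()])
evaluation.run(
    model,
    output_folder='results-chunked-pooling',
    eval_splits=['test'],
    overwrite_results=True,
    encode_kwargs={'batch_size': 1},
)

The remove_prompt_name wrapper shown in the diff fits the same flow: it appears intended to patch the model's encode before the run so that MTEB's prompt_name keyword never reaches encoders that do not accept it.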
138 changes: 0 additions & 138 deletions tests/test_chunking.py

This file was deleted.
