Commit: Merge branch 'main' into test-chunking
violenil authored Sep 20, 2024
2 parents d0261d1 + 3deb3d9 commit 57699cb
Showing 7 changed files with 238 additions and 39 deletions.
2 changes: 1 addition & 1 deletion chunked_pooling/__init__.py
@@ -49,7 +49,7 @@ def chunked_pooling(
if (end - start) >= 1
]
pooled_embeddings = [
- embedding.detach().cpu().numpy() for embedding in pooled_embeddings
+ embedding.float().detach().cpu().numpy() for embedding in pooled_embeddings
]
outputs.append(pooled_embeddings)

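The added .float() cast matters when embeddings are computed in half precision: NumPy has no bfloat16 type, so torch refuses a direct conversion. A minimal standalone sketch (plain PyTorch, not project code):

    import torch

    emb = torch.randn(4, 8, dtype=torch.bfloat16)
    # emb.cpu().numpy() would raise: TypeError: Got unsupported ScalarType BFloat16
    arr = emb.float().detach().cpu().numpy()  # upcast to float32 first, then convert
    print(arr.dtype)  # float32
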
75 changes: 54 additions & 21 deletions chunked_pooling/mteb_chunked_eval.py
@@ -24,19 +24,24 @@ def __init__(
prune_size: Optional[int] = None,
chunk_size: Optional[int] = None,
n_sentences: Optional[int] = None,
+ model_has_instructions: bool = False,
**kwargs,
):
super().__init__(**kwargs)
- self.retrieval_task = getattr(
-     Retrieval,
-     self.metadata_dict['dataset'].get('name', None)
-     or self.metadata_dict.get('name'),
- )()
+ try:
+     self.retrieval_task = getattr(
+         Retrieval,
+         self.metadata_dict['dataset'].get('name', None)
+         or self.metadata_dict.get('name'),
+     )()
+ except:
+     logger.warning('Could not initialize retrieval_task')
self.chunking_strategy = chunking_strategy
self.chunker = Chunker(self.chunking_strategy)
self.chunked_pooling_enabled = chunked_pooling_enabled
self.tokenizer = tokenizer
self.prune_size = prune_size
+ self.model_has_instructions = model_has_instructions
self.chunking_args = {
'chunk_size': chunk_size,
'n_sentences': n_sentences,
@@ -102,7 +107,10 @@ def _evaluate_monolingual(
else:
query_ids = list(queries.keys())
query_texts = [queries[k] for k in query_ids]
- query_embs = model.encode(query_texts)
+ if hasattr(model, 'encode_queries'):
+     query_embs = model.encode_queries(query_texts)
+ else:
+     query_embs = model.encode(query_texts)

corpus_ids = list(corpus.keys())
corpus_texts = [
@@ -114,17 +122,7 @@
for k in corpus_ids
]

- chunk_annotations = [
-     self._extend_special_tokens(
-         self.chunker.chunk(
-             text,
-             self.tokenizer,
-             chunking_strategy=self.chunking_strategy,
-             **self.chunking_args,
-         )
-     )
-     for text in corpus_texts
- ]
+ chunk_annotations = self._calculate_annotations(model, corpus_texts)

corpus_embs = []
with torch.no_grad():
@@ -135,7 +133,11 @@
),
total=(len(corpus_texts) // batch_size),
):
- text_inputs = [x[0] for x in inputs]
+ if self.model_has_instructions:
+     instr = model.get_instructions()[1]
+ else:
+     instr = ''
+ text_inputs = [instr + x[0] for x in inputs]
annotations = [x[1] for x in inputs]
model_inputs = self.tokenizer(
text_inputs,
@@ -264,6 +266,27 @@ def _apply_chunking(self, corpus, tokenizer):
chunked_corpus[k] = current_doc
return chunked_corpus

+ def _calculate_annotations(self, model, corpus_texts):
+     if self.model_has_instructions:
+         instr = model.get_instructions()[1]
+         instr_tokens = self.tokenizer(instr, add_special_tokens=False)
+         n_instruction_tokens = len(instr_tokens[0])
+     else:
+         n_instruction_tokens = 0
+     chunk_annotations = [
+         self._extend_special_tokens(
+             self.chunker.chunk(
+                 text,
+                 self.tokenizer,
+                 chunking_strategy=self.chunking_strategy,
+                 **self.chunking_args,
+             ),
+             n_instruction_tokens=n_instruction_tokens,
+         )
+         for text in corpus_texts
+     ]
+     return chunk_annotations

@staticmethod
def _flatten_chunks(chunked_corpus):
flattened_corpus = dict()
@@ -283,16 +306,26 @@ def _batch_inputs(li, batch_size):
yield li[i : i + batch_size]

@staticmethod
- def _extend_special_tokens(annotations):
+ def _extend_special_tokens(
+     annotations, n_instruction_tokens=0, include_prefix=True, include_sep=True
+ ):
"""Extends the spans because of additional special tokens, e.g. the CLS token
which are not considered by the chunker.
"""
new_annotations = []
for i in range(len(annotations)):
- left = annotations[i][0] + int(i > 0)  # move everything by one for [CLS]
+ add_left_offset = 1 if (not include_prefix) or int(i > 0) else 0
+ left_offset = 1 + n_instruction_tokens
+ left = (
+     annotations[i][0] + add_left_offset * left_offset
+ )  # move everything by one for [CLS]
+
+ add_sep = 1 if include_sep and ((i + 1) == len(annotations)) else 0
+ right_offset = left_offset + add_sep
  right = (
-     annotations[i][1] + 1 + int((i + 1) == len(annotations))
+     annotations[i][1] + right_offset
) # move everything by one for [CLS] and the last one for [SEP]

new_annotations.append((left, right))
return new_annotations

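The new offset arithmetic in _extend_special_tokens is easiest to see with concrete numbers. A minimal sketch under the defaults (include_prefix=True, include_sep=True), assuming two chunks over nine text tokens and a three-token instruction prefix:

    annotations = [(0, 5), (5, 9)]  # token spans produced by the chunker
    n_instr = 3                     # length of the instruction prefix

    new_annotations = []
    for i, (left, right) in enumerate(annotations):
        left_offset = 1 + n_instr                        # [CLS] plus instruction tokens
        new_left = left + (left_offset if i > 0 else 0)  # first chunk keeps the prefix
        new_right = right + left_offset + (1 if i + 1 == len(annotations) else 0)  # +1 for [SEP]
        new_annotations.append((new_left, new_right))

    print(new_annotations)  # [(0, 9), (9, 14)]: 1 + 3 + 9 + 1 = 14 tokens in total
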
101 changes: 101 additions & 0 deletions chunked_pooling/wrappers.py
@@ -0,0 +1,101 @@
import torch
import torch.nn as nn
from transformers import AutoModel

from typing import List, Union, Optional


class JinaEmbeddingsV3Wrapper(nn.Module):
def __init__(
self, model_name, tasks=['retrieval.query', 'retrieval.passage'], **model_kwargs
):
super().__init__()
self._model = AutoModel.from_pretrained(
model_name, trust_remote_code=True, **model_kwargs
)
self.tasks = tasks

def encode_queries(
self,
sentences: Union[str, List[str]],
*args,
task: Optional[str] = None,
**kwargs,
):
return self._model.encode(sentences, *args, task=self.tasks[0], **kwargs)

def encode_corpus(
self,
sentences: Union[str, List[str]],
*args,
**kwargs,
):
_sentences = [self._construct_document(sentence) for sentence in sentences]
return self._model.encode(_sentences, *args, task=self.tasks[1], **kwargs)

def get_instructions(self):
return [self._model._task_instructions[x] for x in self.tasks]

def forward(self, *args, **kwargs):
task_id = self._model._adaptation_map[self.tasks[1]]
num_examples = kwargs['input_ids'].shape[0]
adapter_mask = torch.full(
(num_examples,), task_id, dtype=torch.int32, device=self._model.device
)
return self._model.forward(*args, adapter_mask=adapter_mask, **kwargs)

def _construct_document(self, doc):
if isinstance(doc, str):
return doc
elif 'title' in doc:
return f'{doc["title"]} {doc["text"].strip()}'
else:
return doc['text'].strip()

@property
def device(self):
return self._model.device

@staticmethod
def has_instructions():
return True


MODEL_WRAPPERS = {'jinaai/jina-embeddings-v3': JinaEmbeddingsV3Wrapper}
MODELS_WITHOUT_PROMPT_NAME_ARG = [
'jinaai/jina-embeddings-v2-small-en',
'jinaai/jina-embeddings-v2-base-en',
'jinaai/jina-embeddings-v3',
]


def remove_unsupported_kwargs(original_encode):
def wrapper(self, *args, **kwargs):
# Remove 'prompt_name' from kwargs if present
kwargs.pop('prompt_name', None)
kwargs.pop('request_qid', None)
return original_encode(self, *args, **kwargs)

return wrapper


def load_model(model_name, **model_kwargs):
if model_name in MODEL_WRAPPERS:
model = MODEL_WRAPPERS[model_name](model_name, **model_kwargs)
has_instructions = MODEL_WRAPPERS[model_name].has_instructions()
else:
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
has_instructions = False

# encode functions of various models do not support all sentence transformers kwargs parameter
if model_name in MODELS_WITHOUT_PROMPT_NAME_ARG:
ENCODE_FUNC_NAMES = ['encode', 'encode_queries', 'encode_corpus']
for func_name in ENCODE_FUNC_NAMES:
if hasattr(model, func_name):
setattr(
model,
func_name,
remove_unsupported_kwargs(getattr(model, func_name)),
)

return model, has_instructions
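
A minimal usage sketch of the loader and wrapper (assumes the jinaai/jina-embeddings-v3 checkpoint can be downloaded; the calls follow the task routing defined above):

    from chunked_pooling.wrappers import load_model

    model, has_instructions = load_model('jinaai/jina-embeddings-v3')

    # queries use the 'retrieval.query' adapter, documents 'retrieval.passage'
    query_embs = model.encode_queries(['What is late chunking?'])
    corpus_embs = model.encode_corpus([{'title': 'Chunking', 'text': 'Pool after encoding.'}])
    print(has_instructions)  # True for the v3 wrapper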
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -2,7 +2,7 @@
name = "late_chunking"
requires-python = "~=3.8"
dependencies = [
"jupyterlab==4.2.4",
"jupyterlab==4.2.5",
"transformers==4.43.4",
"torch==2.4.0",
"mteb==1.14.20",
27 changes: 11 additions & 16 deletions run_chunked_eval.py
@@ -9,18 +9,12 @@
SciFactChunked,
TRECCOVIDChunked)

+ from chunked_pooling.wrappers import load_model

DEFAULT_CHUNKING_STRATEGY = 'fixed'
DEFAULT_CHUNK_SIZE = 256
DEFAULT_N_SENTENCES = 5


- def remove_prompt_name(original_encode):
-     def wrapper(self, *args, **kwargs):
-         # Remove 'prompt_name' from kwargs if present
-         kwargs.pop('prompt_name', None)
-         return original_encode(self, *args, **kwargs)
-
-     return wrapper
+ BATCH_SIZE = 1


@click.command()
@@ -43,24 +37,23 @@ def main(model_name, strategy, task_name):
except:
raise ValueError(f'Unknown task name: {task_name}')

- model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
- if model_name == 'jinaai/jina-embeddings-v2-small-en':
-     print("Overwriting encode")
-     model.encode = remove_prompt_name(model.encode)
+ model, has_instructions = load_model(model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

chunking_args = {
'chunk_size': DEFAULT_CHUNK_SIZE,
'n_sentences': DEFAULT_N_SENTENCES,
'chunking_strategy': strategy,
+ 'model_has_instructions': has_instructions,
}

if torch.cuda.is_available():
model = model.cuda()

model.eval()

- # Evaluate with chunking
+ # Evaluate with late chunking
tasks = [
task_cls(
chunked_pooling_enabled=True,
Expand All @@ -82,9 +75,11 @@ def main(model_name, strategy, task_name):
output_folder='results-chunked-pooling',
eval_splits=['test'],
overwrite_results=True,
- encode_kwargs={'batch_size': 1},
+ batch_size=BATCH_SIZE,
+ encode_kwargs={'batch_size': BATCH_SIZE},
)

+ # Encode without late chunking
tasks = [
task_cls(
chunked_pooling_enabled=False,
Expand All @@ -106,7 +101,7 @@ def main(model_name, strategy, task_name):
output_folder='results-normal-pooling',
eval_splits=['test'],
overwrite_results=True,
- encode_kwargs={'batch_size': 1},
+ encode_kwargs={'batch_size': BATCH_SIZE},
)


48 changes: 48 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,48 @@
import pytest

from mteb.abstasks.TaskMetadata import TaskMetadata

from chunked_pooling.mteb_chunked_eval import AbsTaskChunkedRetrieval


class DummyTask(AbsTaskChunkedRetrieval):
metadata = TaskMetadata(
dataset={
'path': '~',
'revision': '',
},
name='dummy',
description='',
type='Retrieval',
category='s2p',
reference=None,
eval_splits=[],
eval_langs=[],
main_score='ndcg_at_10',
date=None,
form=None,
domains=None,
task_subtypes=None,
license=None,
socioeconomic_status=None,
annotations_creators=None,
dialect=None,
text_creation=None,
bibtex_citation=None,
n_samples=None,
avg_character_length=None,
)

def load_data():
pass

def __init__(self, **kwargs):
super().__init__(**kwargs)


@pytest.fixture()
def dummy_task_factory():
def _create_dummy_task(*args, **kwargs):
return DummyTask(*args, **kwargs)

return _create_dummy_task
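
The factory fixture lets each test construct the dummy task with its own arguments. A hypothetical test using it (the test name and tokenizer checkpoint are chosen for illustration):

    from transformers import AutoTokenizer

    def test_fixed_strategy(dummy_task_factory):
        tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        task = dummy_task_factory(
            chunking_strategy='fixed', chunk_size=256, tokenizer=tokenizer
        )
        assert task.chunking_strategy == 'fixed'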
22 changes: 22 additions & 0 deletions tests/test_v3.py
@@ -0,0 +1,22 @@
from transformers import AutoTokenizer

from run_chunked_eval import load_model, DEFAULT_CHUNK_SIZE

MODEL_NAME = 'jinaai/jina-embeddings-v3'


def test_instruction_handling(dummy_task_factory):
model, has_instructions = load_model(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
task = dummy_task_factory(
chunking_strategy='fixed',
chunk_size=DEFAULT_CHUNK_SIZE,
tokenizer=tokenizer,
model_has_instructions=has_instructions,
)
n_instruction_tokens = len(
tokenizer(model.get_instructions()[1], add_special_tokens=False)['input_ids']
)
annotations_one_token = task._calculate_annotations(model, ['A'])[0]
assert len(annotations_one_token) == 1
assert annotations_one_token[0] == (0, n_instruction_tokens + 3)
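
The expected span follows from _extend_special_tokens with its defaults: assuming the chunker yields a single chunk (0, 1) for the one-token input 'A', the left edge stays at 0 because the first chunk keeps the instruction prefix, and the right edge becomes 1 + (1 + n_instruction_tokens) + 1, covering [CLS], the instruction tokens, the 'A' token, and [SEP]. Hence the expected (0, n_instruction_tokens + 3).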
