Fix data processing issues with deep learning embedding pipelines.
bojan-karlas committed Apr 13, 2024
1 parent aa5b8de commit 3713ff3
Showing 2 changed files with 28 additions and 10 deletions.
experiments/datascope/experiments/pipelines/base.py (28 changes: 21 additions & 7 deletions)
@@ -25,7 +25,7 @@
from transformers.image_processing_utils import BatchFeature, BaseImageProcessor
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndNoAttention
from transformers.modeling_utils import PreTrainedModel
-from typing import Dict, Iterable, Type, Optional, Union, List, Tuple
+from typing import Dict, Iterable, Type, Optional, Union, List, Tuple, Callable

from .utility import TorchImageDataset, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from ..datasets import Dataset, TabularDatasetMixin, ImageDatasetMixin, TextDatasetMixin
@@ -286,48 +286,62 @@ def get_preprocessor(cls: Type["ImageEmbeddingPipeline"]) -> transforms.Transform:
    def get_model(cls: Type["ImageEmbeddingPipeline"]) -> PreTrainedModel:
        pass

+    @staticmethod
+    def model_forward(
+        model: PreTrainedModel, batch: Union[Dict[str, torch.Tensor], List[torch.Tensor], torch.Tensor]
+    ) -> torch.Tensor:
+        output: BaseModelOutputWithPoolingAndNoAttention = model(batch)
+        return torch.squeeze(output.pooler_output, dim=(2, 3))

    @staticmethod
    def _embedding_transform(
-        X: np.ndarray, cuda_mode: bool, preprocessor: transforms.Transform, model: PreTrainedModel
+        X: np.ndarray,
+        cuda_mode: bool,
+        preprocessor: transforms.Transform,
+        model: PreTrainedModel,
+        model_forward_function: Callable[
+            [PreTrainedModel, Union[Dict[str, torch.Tensor], List[torch.Tensor], torch.Tensor]], torch.Tensor
+        ],
    ) -> np.ndarray:

        results: List[np.ndarray] = []
        batch_size = BATCH_SIZE
        success = False
        device = "cuda:0" if cuda_mode else "cpu"
        model = model.to(device)
        dataset = TorchImageDataset(X, None, preprocessor, device=device)

        while not success:
            try:
                loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
                for batch in loader:
                    with torch.no_grad():
-                        output: BaseModelOutputWithPoolingAndNoAttention = model(batch)
-                        result = output.pooler_output
+                        result = model_forward_function(model, batch)
                    if cuda_mode:
                        result = result.cpu()
-                    results.append(np.squeeze(result.numpy(), axis=(2, 3)))
+                    results.append(result.numpy())
                success = True

            except torch.cuda.OutOfMemoryError:  # type: ignore
                batch_size = batch_size // 2
                results = []
                print("New batch size: ", batch_size)

        model.cpu()
        torch.cuda.empty_cache()
        return np.concatenate(results)

    @classmethod
    def construct(cls: Type["ImageEmbeddingPipeline"], dataset: Dataset) -> "ImageEmbeddingPipeline":
        cuda_mode = torch.cuda.is_available()
        preprocessor = cls.get_preprocessor()
        model = cls.get_model()
        if cuda_mode:
            model = model.to("cuda:0")
        embedding_transform = functools.partial(
            ImageEmbeddingPipeline._embedding_transform,
            cuda_mode=cuda_mode,
            preprocessor=preprocessor,
            model=model,
+            model_forward_function=cls.model_forward,
        )

        ops = [("embedding", FunctionTransformer(embedding_transform))]
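
The base.py change factors the forward pass out of _embedding_transform into an overridable model_forward static method and injects it through functools.partial as model_forward_function. The squeeze that previously ran on the NumPy result (np.squeeze(..., axis=(2, 3))) now runs on the tensor inside model_forward: ResNet-style encoders in transformers return pooler_output with shape (batch, hidden, 1, 1), which torch.squeeze collapses to (batch, hidden). A minimal sketch of that shape contract, assuming PyTorch >= 2.0 (tuple dims for torch.squeeze) and purely illustrative sizes:

import torch

# pooler_output from a ResNet-style encoder: (batch, hidden, 1, 1)
pooler_output = torch.randn(32, 2048, 1, 1)

# Collapse the two trailing spatial dimensions, as model_forward does.
embeddings = torch.squeeze(pooler_output, dim=(2, 3))
assert embeddings.shape == (32, 2048)

Because the forward pass is now a parameter, a subclass can swap in a different extraction for encoders whose outputs differ, while reusing the out-of-memory retry loop that halves the batch size until a batch fits on the GPU.
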
experiments/datascope/experiments/pipelines/utility.py (10 changes: 7 additions & 3 deletions)
@@ -8,8 +8,8 @@

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
-OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
-OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
+OPENAI_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
+OPENAI_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)


class TorchImageDataset(torch.utils.data.Dataset):
@@ -50,7 +50,11 @@ def __getitem__(self, idx: Union[int, Sequence[int]]):
return {"pixel_values": X_items, "labels": y_items}

    def __getitems__(self, idx: Sequence[int]):
-        return self.__getitem__(idx)
+        if self.y is None:
+            return list(self.__getitem__(idx))
+        else:
+            result = self.__getitem__(idx)
+            return {"pixel_values": list(result["pixel_values"]), "labels": list(result["labels"])}

    def __len__(self):
        return len(self.y)
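
The __getitems__ change addresses how PyTorch batches map-style datasets: when a dataset defines __getitems__ (supported since PyTorch 2.0), the DataLoader hands it the whole list of indices for a batch and passes the return value straight to collate_fn, instead of calling __getitem__ once per index. Returning batched tensors directly, as the old implementation did, is not the list-of-samples form a collate function generally expects, so the new code unpacks the result into Python lists. A minimal sketch of the protocol, assuming PyTorch >= 2.0 and the default collate; ToyDataset and its sizes are illustrative, not part of the commit:

import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    def __init__(self, X: torch.Tensor):
        self.X = X

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.X[idx]

    def __getitems__(self, indices):
        # One vectorized gather, then unpack into the list of per-sample
        # tensors that default_collate expects to stack back together.
        return list(self.X[indices])

    def __len__(self) -> int:
        return len(self.X)

loader = DataLoader(ToyDataset(torch.randn(10, 3)), batch_size=4)
batch = next(iter(loader))  # default_collate stacks the list into shape (4, 3)

In the labeled case the commit returns a dict whose values are per-sample lists rather than a flat list, presumably to match the collate function these experiments use downstream.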
