From e0c6d4909eb34109a6355d1a0fd2cdb069347547 Mon Sep 17 00:00:00 2001
From: sanaAyrml
Date: Thu, 9 Jan 2025 06:58:07 -0500
Subject: [PATCH] Add caching option to dataloader

---
 research/rxrx1/data/data_utils.py  | 21 ++++++++++++-----
 research/rxrx1/data/dataset.py     | 51 ++++++++++++++++++++++++------------
 research/rxrx1/evaluate_on_test.py |  4 ++-
 3 files changed, 56 insertions(+), 20 deletions(-)

diff --git a/research/rxrx1/data/data_utils.py b/research/rxrx1/data/data_utils.py
index 73f995052..9415af3b9 100644
--- a/research/rxrx1/data/data_utils.py
+++ b/research/rxrx1/data/data_utils.py
@@ -69,7 +69,12 @@ def create_splits(


 def load_rxrx1_data(
-    data_path: Path, client_num: int, batch_size: int, seed: int | None = None, train_val_split: float = 0.8
+    data_path: Path,
+    client_num: int,
+    batch_size: int,
+    seed: int | None = None,
+    train_val_split: float = 0.8,
+    num_workers: int = 0,
 ) -> tuple[DataLoader, DataLoader, dict[str, int]]:
     # Read the CSV file
@@ -79,7 +84,9 @@

     train_set, validation_set = create_splits(dataset, seed=seed, train_fraction=train_val_split)

-    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
-    validation_loader = DataLoader(validation_set, batch_size=batch_size)
+    train_loader = DataLoader(
+        train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True
+    )
+    validation_loader = DataLoader(validation_set, batch_size=batch_size, num_workers=num_workers, pin_memory=True)
     num_examples = {
         "train_set": len(train_set),
@@ -89,13 +96,17 @@
     return train_loader, validation_loader, num_examples


-def load_rxrx1_test_data(data_path: Path, client_num: int, batch_size: int) -> tuple[DataLoader, dict[str, int]]:
+def load_rxrx1_test_data(
+    data_path: Path, client_num: int, batch_size: int, num_workers: int = 0
+) -> tuple[DataLoader, dict[str, int]]:
     # Read the CSV file
     data = pd.read_csv(f"{data_path}/clients/meta_data_{client_num+1}.csv")

     evaluation_set = Rxrx1Dataset(metadata=data, root=data_path, dataset_type="test", transform=None)

-    evaluation_loader = DataLoader(evaluation_set, batch_size=batch_size, shuffle=False)
+    evaluation_loader = DataLoader(
+        evaluation_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
+    )
     num_examples = {"eval_set": len(evaluation_set)}

     return evaluation_loader, num_examples
diff --git a/research/rxrx1/data/dataset.py b/research/rxrx1/data/dataset.py
index 34cf48d11..bc6cf9a1b 100644
--- a/research/rxrx1/data/dataset.py
+++ b/research/rxrx1/data/dataset.py
@@ -1,6 +1,7 @@
 import os
 from collections.abc import Callable
 from pathlib import Path
+from typing import Any

 import pandas as pd
 import torch
@@ -10,7 +11,14 @@


 class Rxrx1Dataset(Dataset):
-    def __init__(self, metadata: pd.DataFrame, root: Path, dataset_type: str, transform: Callable | None = None):
+    def __init__(
+        self,
+        metadata: pd.DataFrame,
+        root: Path,
+        dataset_type: str,
+        transform: Callable | None = None,
+        cache_images: bool = False,
+    ):
         """
         Args:
             metadata (DataFrame): A DataFrame containing image metadata.
@@ -20,33 +28,48 @@ def __init__(self, metadata: pd.DataFrame, root: Path, dataset_type: str, transf
+            cache_images (bool): If True, load every image into memory once at
+                construction time and serve items from that cache. Defaults to False.
         """
         self.metadata = metadata[metadata["dataset"] == dataset_type]
         self.root = root
-        self.transform = transform if transform else ToTensor()
+        self.transform = transform  # May be None; ToTensor conversion happens in load_image
         self.label_map = {label: idx for idx, label in enumerate(sorted(self.metadata["sirna_id"].unique()))}
         self.metadata["mapped_label"] = self.metadata["sirna_id"].map(self.label_map)

+        # Optionally cache all images as tensors up front. The transform is applied
+        # per access in __getitem__, so random augmentations still vary across epochs.
+        self.images: list[torch.Tensor] | None = None
+        if cache_images:
+            self.images = [self.load_image(dict(row)) for _, row in self.metadata.iterrows()]
+
     def __len__(self) -> int:
         return len(self.metadata)

     def __getitem__(self, idx: int) -> tuple[torch.Tensor, int]:
-        row = self.metadata.iloc[idx]
+        row = dict(self.metadata.iloc[idx])
+
+        if self.images is not None:
+            image = self.images[idx]
+        else:
+            image = self.load_image(row)
+
+        if self.transform is not None:
+            image = self.transform(image)
+
+        label = int(row["mapped_label"])
+
+        return image, label
+
+    def load_image(self, row: dict[str, Any]) -> torch.Tensor:
         experiment = row["experiment"]
         plate = row["plate"]
         well = row["well"]
         site = row["site"]
-        label = row["mapped_label"]  # Get the label index

         images = []
         for channel in range(1, 4):
             image_path = os.path.join(self.root, f"images/{experiment}/Plate{plate}/{well}_s{site}_w{channel}.png")
-            image = self.load_image(image_path)
+            if not Path(image_path).exists():
+                raise FileNotFoundError(f"Image not found at {image_path}")
+            image = ToTensor()(Image.open(image_path).convert("L"))  # Load as grayscale
             images.append(image)
-
-        concatenated_image = torch.cat(images, dim=0)
-        return concatenated_image, label
-
-    def load_image(self, path: str) -> torch.Tensor:
-        if not Path(path).exists():
-            raise FileNotFoundError(f"Image not found at {path}")
-        image = Image.open(path).convert("L")  # Load as grayscale
-        return self.transform(image)
+        return torch.cat(images, dim=0)
diff --git a/research/rxrx1/evaluate_on_test.py b/research/rxrx1/evaluate_on_test.py
index 396c3acb6..155215286 100644
--- a/research/rxrx1/evaluate_on_test.py
+++ b/research/rxrx1/evaluate_on_test.py
@@ -74,7 +74,9 @@ def main(
         meta_data = pd.concat([meta_data, test_loader.dataset.metadata])
     test_loader.dataset.metadata = meta_data

-    aggregated_test_loader = torch.utils.data.DataLoader(test_loader.dataset, batch_size=BATCH_SIZE, shuffle=False)
+    aggregated_test_loader = torch.utils.data.DataLoader(
+        test_loader.dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True
+    )
     aggregated_num_examples = len(meta_data)

     for client_number in range(NUM_CLIENTS):
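---

A minimal sketch of how the new options fit together. The data path, client
index, batch size, and worker count below are illustrative placeholders, and
the imports assume the repository root is on PYTHONPATH:

    from pathlib import Path

    import pandas as pd

    from research.rxrx1.data.data_utils import load_rxrx1_data
    from research.rxrx1.data.dataset import Rxrx1Dataset

    data_path = Path("/datasets/rxrx1")  # placeholder path

    # num_workers > 0 overlaps disk reads and PNG decoding with training;
    # pin_memory=True (set inside load_rxrx1_data) speeds host-to-GPU copies.
    train_loader, val_loader, num_examples = load_rxrx1_data(
        data_path=data_path, client_num=0, batch_size=32, seed=42, num_workers=4
    )

    # cache_images=True trades RAM for I/O: every image is decoded once in
    # __init__, and __getitem__ then serves tensors straight from memory.
    metadata = pd.read_csv(data_path / "clients" / "meta_data_1.csv")
    cached_train_set = Rxrx1Dataset(
        metadata=metadata,
        root=data_path,
        dataset_type="train",  # assumed value of the metadata "dataset" column
        cache_images=True,
    )

Note that cache_images and num_workers interact: each worker process holds its
own view of the cached list, so caching pays off most with num_workers=0, or on
Linux where forked workers share the cache pages copy-on-write.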