Skip to content

Commit

Permalink
Add caching option to dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
sanaAyrml committed Jan 9, 2025
1 parent 3b56645 commit e0c6d49
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 19 deletions.
17 changes: 13 additions & 4 deletions research/rxrx1/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,12 @@ def create_splits(


def load_rxrx1_data(
data_path: Path, client_num: int, batch_size: int, seed: int | None = None, train_val_split: float = 0.8
data_path: Path,
client_num: int,
batch_size: int,
seed: int | None = None,
train_val_split: float = 0.8,
num_workers: int = 0,
) -> tuple[DataLoader, DataLoader, dict[str, int]]:

# Read the CSV file
Expand All @@ -79,7 +84,7 @@ def load_rxrx1_data(

train_set, validation_set = create_splits(dataset, seed=seed, train_fraction=train_val_split)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
validation_loader = DataLoader(validation_set, batch_size=batch_size)
num_examples = {
"train_set": len(train_set),
Expand All @@ -89,13 +94,17 @@ def load_rxrx1_data(
return train_loader, validation_loader, num_examples


def load_rxrx1_test_data(
    data_path: Path, client_num: int, batch_size: int, num_workers: int = 0
) -> tuple[DataLoader, dict[str, int]]:
    """Build the evaluation DataLoader for a single rxrx1 client.

    Args:
        data_path: Root directory of the rxrx1 dataset.
        client_num: Client index; metadata CSVs on disk are stored 1-indexed,
            so presumably this argument is zero-based — confirm against callers.
        batch_size: Batch size for the evaluation loader.
        num_workers: Number of DataLoader worker processes (0 loads in the
            main process).

    Returns:
        The evaluation DataLoader and a dict holding the evaluation set size
        under the key "eval_set".
    """
    # Read the per-client metadata CSV (note the +1: files are 1-indexed).
    data = pd.read_csv(f"{data_path}/clients/meta_data_{client_num+1}.csv")

    evaluation_set = Rxrx1Dataset(metadata=data, root=data_path, dataset_type="test", transform=None)

    # pin_memory speeds up host-to-GPU transfer of evaluation batches.
    evaluation_loader = DataLoader(
        evaluation_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
    )
    num_examples = {"eval_set": len(evaluation_set)}
    return evaluation_loader, num_examples
42 changes: 28 additions & 14 deletions research/rxrx1/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from collections.abc import Callable
from pathlib import Path
from typing import Any

import pandas as pd
import torch
Expand All @@ -10,7 +11,14 @@


class Rxrx1Dataset(Dataset):
def __init__(self, metadata: pd.DataFrame, root: Path, dataset_type: str, transform: Callable | None = None):
def __init__(
self,
metadata: pd.DataFrame,
root: Path,
dataset_type: str,
transform: Callable | None = None,
cache_images: bool = False,
):
"""
Args:
metadata (DataFrame): A DataFrame containing image metadata.
Expand All @@ -20,33 +28,39 @@ def __init__(self, metadata: pd.DataFrame, root: Path, dataset_type: str, transf
"""
self.metadata = metadata[metadata["dataset"] == dataset_type]
self.root = root
self.transform = transform if transform else ToTensor()
self.transform = transform if transform else None

self.label_map = {label: idx for idx, label in enumerate(sorted(self.metadata["sirna_id"].unique()))}
self.metadata["mapped_label"] = self.metadata["sirna_id"].map(self.label_map)

if cache_images:
self.images = [self.load_image(dict(row)) for _, row in self.metadata.iterrows()]

def __len__(self) -> int:
return len(self.metadata)

def __getitem__(self, idx: int) -> tuple[torch.Tensor, int]:
row = self.metadata.iloc[idx]
row = dict(self.metadata.iloc[idx])

if hasattr(self, "images"):
image = self.images[idx]
else:
image = self.load_image(row)
label = row["mapped_label"]

return image, label

def load_image(self, row: dict[str, Any]) -> torch.Tensor:
experiment = row["experiment"]
plate = row["plate"]
well = row["well"]
site = row["site"]
label = row["mapped_label"] # Get the label index

images = []
for channel in range(1, 4):
image_path = os.path.join(self.root, f"images/{experiment}/Plate{plate}/{well}_s{site}_w{channel}.png")
image = self.load_image(image_path)
if not Path(image_path).exists():
raise FileNotFoundError(f"Image not found at {image_path}")
image = ToTensor(Image.open(image_path).convert("L"))
images.append(image)

concatenated_image = torch.cat(images, dim=0)
return concatenated_image, label

def load_image(self, path: str) -> torch.Tensor:
if not Path(path).exists():
raise FileNotFoundError(f"Image not found at {path}")
image = Image.open(path).convert("L") # Load as grayscale
return self.transform(image)
return torch.cat(images, dim=0)
4 changes: 3 additions & 1 deletion research/rxrx1/evaluate_on_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def main(
meta_data = pd.concat([meta_data, test_loader.dataset.metadata])
test_loader.dataset.metadata = meta_data

aggregated_test_loader = torch.utils.data.DataLoader(test_loader.dataset, batch_size=BATCH_SIZE, shuffle=False)
aggregated_test_loader = torch.utils.data.DataLoader(
test_loader.dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True
)
aggregated_num_examples = len(meta_data)

for client_number in range(NUM_CLIENTS):
Expand Down

0 comments on commit e0c6d49

Please sign in to comment.