Commit 1212a5a

Modify the shuffling of indices and use ConcatDataset instead of RandomSampler
edyoshikun committed Jan 30, 2025
1 parent 822dbba commit 1212a5a
Showing 3 changed files with 170 additions and 179 deletions.
9 changes: 2 additions & 7 deletions viscy/data/hcs.py
@@ -25,6 +25,7 @@
from torch.utils.data import DataLoader, Dataset

from viscy.data.typing import ChannelMap, DictTransform, HCSStackIndex, NormMeta, Sample
from viscy.utils.engine_state import set_fit_global_state

_logger = logging.getLogger("lightning.pytorch")

@@ -426,12 +427,6 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]):
else:
raise NotImplementedError(f"{stage} stage")

def _set_fit_global_state(self, num_positions: int) -> torch.Tensor:
# disable metadata tracking in MONAI for performance
set_track_meta(False)
# shuffle positions, randomness is handled globally
return torch.randperm(num_positions)

def _setup_fit(self, dataset_settings: dict):
"""Set up the training and validation datasets."""
train_transform, val_transform = self._fit_transform()
@@ -441,7 +436,7 @@ def _setup_fit(self, dataset_settings: dict):

# shuffle positions, randomness is handled globally
positions = [pos for _, pos in plate.positions()]
shuffled_indices = self._set_fit_global_state(len(positions))
shuffled_indices = set_fit_global_state(len(positions))
positions = list(positions[i] for i in shuffled_indices)
num_train_fovs = int(len(positions) * self.split_ratio)
# training set needs to sample more Z range for augmentation
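The removed helper now lives in viscy.utils.engine_state (see the new import above) so that other data modules, including tarrow.py below, can share it. A minimal sketch of the shared function, assuming it simply keeps the body of the removed method:

import torch
from monai.data import set_track_meta


def set_fit_global_state(num_positions: int) -> torch.Tensor:
    """Prepare global state for fitting and return shuffled position indices."""
    # disable metadata tracking in MONAI for performance
    set_track_meta(False)
    # shuffle positions; randomness is handled by the global torch seed
    return torch.randperm(num_positions)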
120 changes: 70 additions & 50 deletions viscy/data/tarrow.py
@@ -1,11 +1,13 @@
from pathlib import Path
from typing import Callable

import numpy as np
import torch
from iohub.ngff import Position, open_ome_zarr
from lightning.pytorch import LightningDataModule
from tarrow.data.tarrow_dataset import TarrowDataset
from torch.utils.data import DataLoader
from torch.utils.data import ConcatDataset, DataLoader

from viscy.utils.engine_state import set_fit_global_state


class TarrowDataModule(LightningDataModule):
@@ -19,6 +21,8 @@ class TarrowDataModule(LightningDataModule):
Name of the channel to load
train_split : float, default=0.8
Fraction of data to use for training (0.0 to 1.0)
patch_size : tuple[int, int], default=(128, 128)
Patch size for TarrowDataset
batch_size : int, default=16
Batch size for dataloaders
num_workers : int, default=8
@@ -33,6 +37,8 @@ class TarrowDataModule(LightningDataModule):
Number of validation samples per epoch
resolution : int, default=0
Resolution level to load from OME-Zarr
normalization : function, optional (default=None)
Normalization function to apply to images
z_slice : int, default=0
Z-slice to load
pin_memory : bool, default=True
@@ -50,12 +56,14 @@ def __init__(
train_split: float = 0.8,
batch_size: int = 16,
num_workers: int = 8,
patch_size: tuple[int, int] = (128, 128),
prefetch_factor: int | None = None,
include_fov_names: list[str] = [],
train_samples_per_epoch: int = 100000,
val_samples_per_epoch: int = 10000,
resolution: int = 0,
z_slice: int = 0,
normalization: Callable[[np.ndarray], np.ndarray] | None = None,
pin_memory: bool = True,
persistent_workers: bool = True,
**kwargs,
@@ -67,12 +75,17 @@ def __init__(
self.batch_size = batch_size
self.num_workers = num_workers
self.prefetch_factor = prefetch_factor
self.patch_size = patch_size
self.include_fov_names = include_fov_names
self.train_samples_per_epoch = train_samples_per_epoch
self.val_samples_per_epoch = val_samples_per_epoch
self.resolution = resolution
self.z_slice = z_slice
self.kwargs = kwargs
self.normalization = normalization

self._filter_positions()
self._channel_idx = self._get_channel_index()

def _get_channel_index(self, plate) -> int:
"""Get the index of the specified channel from the plate metadata.
@@ -102,15 +115,13 @@ def _get_channel_index(self, plate) -> int:
f"Channel '{self.channel_name}' not found. Available channels: {available_channels}"
)

def _load_images(
self, positions: list[Position], channel_idx: int
) -> list[np.ndarray]:
def _load_images(self, position: Position, channel_idx: int) -> list[np.ndarray]:
"""Load all images from positions into memory.
Parameters
----------
positions : list[Position]
List of positions to load
position : Position
Position to load
channel_idx : int
Index of channel to load
@@ -120,11 +131,10 @@ def _load_images(
List of 2D numpy arrays
"""
imgs = []
for pos in positions:
img_arr = pos[str(self.resolution)]
# Load all timepoints for this position
for t in range(len(img_arr)):
imgs.append(img_arr[t, channel_idx, self.z_slice])
img_arr = position[str(self.resolution)]
# Load all timepoints for this position
for t in range(len(img_arr)):
imgs.append(img_arr[t, channel_idx, self.z_slice])
return imgs

def setup(self, stage: str):
@@ -140,41 +150,34 @@ def setup(self, stage: str):
NotImplementedError
If stage is not "fit"
"""
plate = open_ome_zarr(self.ome_zarr_path, mode="r")

# Get channel index once
channel_idx = self._get_channel_index(plate)

# Get the positions to load
if self.include_fov_names:
positions = []
for fov_str, pos in plate.positions():
normalized_include_fovs = [
f.lstrip("/") for f in self.include_fov_names
]
if fov_str in normalized_include_fovs:
positions.append(pos)
else:
positions = [pos for _, pos in plate.positions()]
if stage == "fit":
list_dataset = []
for pos in self.positions:
pos_imgs = self._load_images(pos, self._channel_idx)
list_dataset.append(
TarrowDataset(
imgs=pos_imgs,
normalize=self.normalization,
size=self.patch_size,
**self.kwargs,
)
)

# Load all images into memory using the pre-determined channel index
imgs = self._load_images(positions, channel_idx)
# Calculate split point
split_idx = int(len(self.positions) * self.train_split)

# Calculate split point
split_idx = int(len(imgs) * self.train_split)
# Shuffle the list of datasets
shuffled_indices = set_fit_global_state(len(list_dataset))
list_dataset = [list_dataset[i] for i in shuffled_indices]

if stage == "fit":
# Create training dataset with first train_split% of images
self.train_dataset = TarrowDataset(
imgs=imgs[:split_idx],
**self.kwargs,
)
self.train_dataset = ConcatDataset(list_dataset[:split_idx])

# Create validation dataset with remaining images
self.val_dataset = TarrowDataset(
imgs=imgs[split_idx:],
**{k: v for k, v in self.kwargs.items() if k != "augmenter"},
)
self.val_dataset = ConcatDataset(list_dataset[split_idx:])

elif stage == "test":
raise NotImplementedError(f"Invalid stage: {stage}")
@@ -183,25 +186,45 @@ def setup(self, stage: str):
else:
raise NotImplementedError(f"Invalid stage: {stage}")
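The substantive change in setup: previously all timepoints from all positions were pooled into a single image list and split by image index, so frames from one FOV could land in both training and validation; now each FOV becomes its own TarrowDataset, the per-FOV datasets are shuffled with the shared global state, and each split is a ConcatDataset of whole FOVs. A condensed sketch of the new split logic, taking the already-built per-FOV TarrowDatasets as input:

import torch
from torch.utils.data import ConcatDataset


def split_by_fov(fov_datasets, train_split=0.8):
    """Shuffle whole-FOV datasets, then split them into train/val ConcatDatasets."""
    # setup() obtains this permutation from the shared set_fit_global_state helper
    shuffled_indices = torch.randperm(len(fov_datasets))
    fov_datasets = [fov_datasets[i] for i in shuffled_indices]
    split_idx = int(len(fov_datasets) * train_split)
    return (
        ConcatDataset(fov_datasets[:split_idx]),
        ConcatDataset(fov_datasets[split_idx:]),
    )

Because the shuffle and split operate on whole FOVs, no FOV can straddle the train/val boundary.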

def _filter_positions(self):
"""Filter positions based on include_fov_names."""
# Get the positions to load
plate = open_ome_zarr(self.ome_zarr_path, mode="r")
if self.include_fov_names:
positions = []
for fov_str, pos in plate.positions():
normalized_include_fovs = [
f.lstrip("/") for f in self.include_fov_names
]
if fov_str in normalized_include_fovs:
positions.append(pos)
else:
positions = [pos for _, pos in plate.positions()]

self.positions = positions
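The lstrip normalization means FOVs can be listed with or without a leading slash. For example (hypothetical FOV names):

include_fov_names = ["/A/1/0", "B/2/1"]
normalized = [f.lstrip("/") for f in include_fov_names]
# normalized == ["A/1/0", "B/2/1"], matched against keys yielded by plate.positions()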

def _get_channel_index(self):
"""Get the index of the specified channel from the plate metadata."""
with open_ome_zarr(self.ome_zarr_path, mode="r") as plate:
_, first_pos = next(plate.positions())
return first_pos.channel_names.index(self.channel_name)

def train_dataloader(self):
"""Create the training dataloader.
Returns
-------
torch.utils.data.DataLoader
DataLoader for training data with random sampling
DataLoader for training data
"""
return DataLoader(
self.train_dataset,
sampler=torch.utils.data.RandomSampler(
self.train_dataset,
replacement=True,
num_samples=self.train_samples_per_epoch,
),
batch_size=self.batch_size,
num_workers=self.num_workers,
persistent_workers=True if self.num_workers > 0 else False,
prefetch_factor=self.prefetch_factor if self.num_workers else None,
pin_memory=True,
shuffle=True,
)
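Dropping the explicit RandomSampler changes the epoch semantics: the old sampler drew train_samples_per_epoch indices with replacement, giving a fixed-length epoch regardless of dataset size, while shuffle=True visits every sample exactly once per epoch, so train_samples_per_epoch and val_samples_per_epoch are no longer consulted by these dataloaders. A minimal side-by-side sketch of the two behaviors, with a toy dataset standing in for self.train_dataset:

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

# toy stand-in for self.train_dataset
dataset = TensorDataset(torch.arange(1000))

# old behavior: fixed-length epoch, indices drawn with replacement
old_loader = DataLoader(
    dataset,
    sampler=RandomSampler(dataset, replacement=True, num_samples=100_000),
    batch_size=16,
)

# new behavior: one full shuffled pass over the dataset per epoch
new_loader = DataLoader(dataset, shuffle=True, batch_size=16)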

def val_dataloader(self):
@@ -210,19 +233,16 @@ def val_dataloader(self):
Returns
-------
torch.utils.data.DataLoader
DataLoader for validation data with random sampling
DataLoader for validation data
"""
return DataLoader(
self.val_dataset,
sampler=torch.utils.data.RandomSampler(
self.val_dataset,
replacement=True,
num_samples=self.val_samples_per_epoch,
),
batch_size=self.batch_size,
num_workers=self.num_workers,
persistent_workers=True if self.num_workers > 0 else False,
prefetch_factor=self.prefetch_factor if self.num_workers else None,
pin_memory=True,
shuffle=False,
)

def test_dataloader(self):
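Since __init__ now calls _filter_positions and _get_channel_index eagerly, the OME-Zarr store must exist and contain the named channel at construction time, not just at setup(). Hypothetical usage, assuming the constructor's leading parameters are ome_zarr_path and channel_name (as the attributes set in __init__ suggest); the store path and channel name below are placeholders:

from viscy.data.tarrow import TarrowDataModule

dm = TarrowDataModule(
    ome_zarr_path="plate.zarr",  # placeholder path to a local OME-Zarr store
    channel_name="Phase3D",      # placeholder channel name
    train_split=0.8,
    patch_size=(128, 128),
    batch_size=16,
    num_workers=4,
)
dm.setup("fit")
batch = next(iter(dm.train_dataloader()))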
