Skip to content

Commit

Permalink
Refactor transforms, add tests for ContrastiveMLP
Browse files Browse the repository at this point in the history
Fix prediction_writer loop
  • Loading branch information
bricewang committed Aug 30, 2024
1 parent 39e2134 commit f7ed8a5
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 48 deletions.
9 changes: 3 additions & 6 deletions cellarium/ml/callbacks/prediction_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import os
from collections.abc import Sequence, Mapping
from collections.abc import Mapping, Sequence
from pathlib import Path

import lightning.pytorch as pl
Expand Down Expand Up @@ -35,7 +35,7 @@ def write_prediction(
os.makedirs(output_dir, exist_ok=True)
df = pd.DataFrame(prediction.cpu())
if fields is not None:
for field_name, field_data in fields:
for field_name, field_data in fields.items():
df.insert(0, field_name, field_data)
df.insert(0, "db_ids", ids)
output_path = os.path.join(output_dir, f"batch_{postfix}.csv")
Expand All @@ -61,10 +61,7 @@ class PredictionWriter(pl.callbacks.BasePredictionWriter):
"""

def __init__(
self,
output_dir: Path | str,
prediction_size: int | None = None,
field_names: Sequence[str] | None = None
self, output_dir: Path | str, prediction_size: int | None = None, field_names: Sequence[str] | None = None
) -> None:
super().__init__(write_interval="batch")
self.output_dir = output_dir
Expand Down
4 changes: 2 additions & 2 deletions cellarium/ml/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class FileLoader:
key: str | None = None
convert_fn: Callable[[Any], Any] | str | None = None

def __new__(cls, file_path, loader_fn, attr, key, convert_fn):
def __new__(cls, file_path, loader_fn, attr=None, key=None, convert_fn=None):
if isinstance(loader_fn, str):
loader_fn = import_object(loader_fn)
if loader_fn not in cached_loaders:
Expand Down Expand Up @@ -121,7 +121,7 @@ class CheckpointLoader(FileLoader):
key: str | None = None
convert_fn: Callable[[Any], Any] | str | None = None

def __new__(cls, file_path, attr, key, convert_fn):
def __new__(cls, file_path, attr=None, key=None, convert_fn=None):
return super().__new__(cls, file_path, CellariumModule.load_from_checkpoint, attr, key, convert_fn)


Expand Down
1 change: 0 additions & 1 deletion cellarium/ml/models/contrastive_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Sequence
from typing import Any, List

import torch
import torch.nn.functional as F
Expand Down
22 changes: 7 additions & 15 deletions cellarium/ml/transforms/binomial_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@

import torch
from torch import nn
from torch.distributions import Bernoulli, Binomial, Uniform
from torch.distributions import Binomial

import logging
logger = logging.getLogger(__name__)
from .randomize import Randomize

# Set the minimum level of messages to log
logger.setLevel(logging.ERROR)

class BinomialResample(nn.Module):
"""
Expand All @@ -29,14 +26,16 @@ class BinomialResample(nn.Module):
Lower bound on binomial distribution parameter.
p_binom_max:
Upper bound on binomial distribution parameter.
p_apply:
Probability of applying transform to each sample.
"""

def __init__(self, p_binom_min: float, p_binom_max: float, p_apply: float):
super().__init__()

self.p_binom_min = p_binom_min
self.p_binom_max = p_binom_max
self.p_apply = p_apply
self.randomize = Randomize(p_apply)

def forward(self, x_ng: torch.Tensor) -> dict[str, torch.Tensor]:
"""
Expand All @@ -46,15 +45,8 @@ def forward(self, x_ng: torch.Tensor) -> dict[str, torch.Tensor]:
Returns:
Binomially resampled gene counts.
"""
p_binom_min = torch.Tensor([self.p_binom_min]).type_as(x_ng)
p_binom_max = torch.Tensor([self.p_binom_max]).type_as(x_ng)
p_apply = torch.Tensor([self.p_apply]).type_as(x_ng)

p_binom_ng = Uniform(p_binom_min, p_binom_max).sample(x_ng.shape).squeeze(-1)

apply_mask_n1 = Bernoulli(probs=p_apply).sample(x_ng.shape[:1]).bool()

p_binom_ng = torch.empty_like(x_ng.shape).uniform_(self.p_binom_min, self.p_binom_max)
x_aug = Binomial(total_count=x_ng, probs=p_binom_ng).sample()

x_ng = torch.where(apply_mask_n1, x_aug, x_ng)
x_ng = self.randomize(x_aug, x_ng)
return {"x_ng": x_ng}
20 changes: 8 additions & 12 deletions cellarium/ml/transforms/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import torch
from torch import nn
from torch.distributions import Bernoulli, Uniform

from .randomize import Randomize


class Dropout(nn.Module):
Expand All @@ -24,14 +25,16 @@ class Dropout(nn.Module):
Lower bound on dropout parameter.
p_dropout_max:
Upper bound on dropout parameter.
p_apply:
Probability of applying transform to each sample.
"""

def __init__(self, p_dropout_min, p_dropout_max, p_apply):
super().__init__()

self.p_dropout_min = p_dropout_min
self.p_dropout_max = p_dropout_max
self.p_apply = p_apply
self.randomize = Randomize(p_apply)

def forward(self, x_ng: torch.Tensor) -> dict[str, torch.Tensor]:
"""
Expand All @@ -41,15 +44,8 @@ def forward(self, x_ng: torch.Tensor) -> dict[str, torch.Tensor]:
Returns:
Gene counts with random dropout.
"""
p_dropout_min = torch.Tensor([self.p_dropout_min]).type_as(x_ng)
p_dropout_max = torch.Tensor([self.p_dropout_max]).type_as(x_ng)
p_apply = torch.Tensor([self.p_apply]).type_as(x_ng)

p_dropout_ng = Uniform(p_dropout_min, p_dropout_max).sample(x_ng.shape).squeeze(-1)
p_apply_n1 = Bernoulli(probs=p_apply).sample(x_ng.shape[:1]).bool()

x_aug = torch.clone(x_ng)
x_aug[Bernoulli(probs=p_dropout_ng).sample().bool()] = 0
p_dropout_ng = torch.empty_like(x_ng.shape).uniform_(self.p_dropout_min, self.p_dropout_max)
x_aug = torch.where(torch.bernoulli(p_dropout_ng).bool(), 0, x_ng)

x_ng = torch.where(p_apply_n1, x_aug, x_ng)
x_ng = self.randomize(x_aug, x_ng)
return {"x_ng": x_ng}
21 changes: 9 additions & 12 deletions cellarium/ml/transforms/gaussian_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import torch
from torch import nn
from torch.distributions import Bernoulli, Normal, Uniform

from .randomize import Randomize


class GaussianNoise(nn.Module):
Expand All @@ -24,31 +25,27 @@ class GaussianNoise(nn.Module):
Lower bound on Gaussian sigma parameter.
sigma_max:
Upper bound on Gaussian sigma parameter.
p_apply:
Probability of applying transform to each sample.
"""

def __init__(self, sigma_min, sigma_max, p_apply):
super().__init__()

self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.p_apply = p_apply
self.randomize = Randomize(p_apply)

def forward(self, x_ng: torch.Tensor) -> dict[str, torch.Tensor]:
"""
Args:
x_ng: Gene counts.
x_ng: Gene counts (log-transformed).
Returns:
Gene counts with added Gaussian noise.
"""
sigma_min = torch.Tensor([self.sigma_min]).type_as(x_ng)
sigma_max = torch.Tensor([self.sigma_max]).type_as(x_ng)
p_apply = torch.Tensor([self.p_apply]).type_as(x_ng)

sigma_ng = Uniform(sigma_min, sigma_max).sample(x_ng.shape).squeeze(-1)
p_apply_n1 = Bernoulli(probs=p_apply).sample(x_ng.shape[:1]).bool()

x_aug = x_ng + Normal(0, sigma_ng).sample()
sigma_ng = torch.empty_like(x_ng.shape).uniform_(self.sigma_min, self.sigma_max)
x_aug = x_ng + torch.normal(std=sigma_ng)

x_ng = torch.where(p_apply_n1, x_aug, x_ng)
x_ng = self.randomize(x_aug, x_ng)
return {"x_ng": x_ng}
37 changes: 37 additions & 0 deletions cellarium/ml/transforms/randomize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause


import torch
from torch import nn


class Randomize(nn.Module):
"""
Randomly selects between the augmented and original data
for each sample according to probability p_apply.
Args:
p_apply:
Probability of selecting augmentation for each sample.
"""

def __init__(self, p_apply):
super().__init__()

self.p_apply = p_apply

def forward(self, x_aug: torch.Tensor, x_ng: torch.Tensor) -> torch.Tensor:
"""
Args:
x_aug: Augmented gene counts.
x_ng: Gene counts.
Returns:
Randomized augmented gene counts.
"""
p_apply_n1 = torch.Tensor([self.p_apply]).expand(x_ng.shape[0], 1).type_as(x_ng)
apply_mask_n1 = torch.bernoulli(p_apply_n1).bool()

x_ng = torch.where(apply_mask_n1, x_aug, x_ng)
return x_ng
1 change: 1 addition & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
"subcommand": "fit",
"fit": {
"model": {
"transforms": [{"class_path": "cellarium.ml.transforms.Duplicate"}],
"model": {
"class_path": "cellarium.ml.models.ContrastiveMLP",
"init_args": {
Expand Down
58 changes: 58 additions & 0 deletions tests/test_contrastive_mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

import math
import os
from pathlib import Path

import lightning.pytorch as pl
import numpy as np
import torch

from cellarium.ml import CellariumModule
from cellarium.ml.models import ContrastiveMLP
from cellarium.ml.transforms import Duplicate
from cellarium.ml.utilities.data import collate_fn
from tests.common import BoringDataset


def test_load_from_checkpoint_multi_device(tmp_path: Path):
n, g = 4, 3
devices = int(os.environ.get("TEST_DEVICES", "1"))
# dataloader
train_loader = torch.utils.data.DataLoader(
BoringDataset(np.arange(n * g).reshape(n, g).astype("float32")),
collate_fn=collate_fn,
)
# model
model = ContrastiveMLP(
n_obs=3,
embed_dim=2,
hidden_size=[2],
temperature=1.0,
)
module = CellariumModule(
transforms=[Duplicate()],
model=model,
optim_fn=torch.optim.Adam,
optim_kwargs={"lr": 1e-3},
)
# trainer
trainer = pl.Trainer(
accelerator="cpu",
devices=devices,
max_epochs=1,
default_root_dir=tmp_path,
)
# fit
trainer.fit(module, train_dataloaders=train_loader)

# run tests only for rank 0
if trainer.global_rank != 0:
return

# load model from checkpoint
ckpt_path = tmp_path / f"lightning_logs/version_0/checkpoints/epoch=0-step={math.ceil(n / devices)}.ckpt"
assert ckpt_path.is_file()
loaded_model = CellariumModule.load_from_checkpoint(ckpt_path).model
assert isinstance(loaded_model, ContrastiveMLP)

0 comments on commit f7ed8a5

Please sign in to comment.