diff --git a/cellarium/ml/core/pipeline.py b/cellarium/ml/core/pipeline.py
index cb98a784..2750b1c3 100644
--- a/cellarium/ml/core/pipeline.py
+++ b/cellarium/ml/core/pipeline.py
@@ -57,7 +57,6 @@ def predict(self, batch: dict[str, np.ndarray | torch.Tensor]) -> dict[str, np.n
             raise TypeError(f"The last module in the pipeline must be an instance of {PredictMixin}. Got {model}")
 
         for module in self[:-1]:
-            print(type(module))
             # get the module input keys
             ann = module.forward.__annotations__
             input_keys = {key for key in ann if key != "return" and key in batch}
diff --git a/cellarium/ml/models/contrastive_mlp.py b/cellarium/ml/models/contrastive_mlp.py
index 921ef7b5..bf140ccd 100644
--- a/cellarium/ml/models/contrastive_mlp.py
+++ b/cellarium/ml/models/contrastive_mlp.py
@@ -11,10 +11,24 @@
 from cellarium.ml.models.model import CellariumModel, PredictMixin
 from cellarium.ml.models.nt_xent import NT_Xent
 
-import pdb
-
 
 class ContrastiveMLP(CellariumModel, PredictMixin):
+    """
+    Multilayer perceptron trained with contrastive learning.
+
+    Args:
+        g_genes:
+            Number of genes in each entry (network input size).
+        hidden_size:
+            Dimensionality of the fully-connected hidden layers.
+        embed_dim:
+            Size of embedding (network output size).
+        world_size:
+            Number of devices used in training.
+        temperature:
+            Parameter governing Normalized Temperature-scaled cross-entropy (NT-Xent) loss.
+    """
+
     def __init__(
         self,
         g_genes: int,
@@ -43,30 +57,44 @@ def __init__(
 
         self.reset_parameters()
 
-
+    def reset_parameters(self) -> None:
+        for layer in self.layers:
+            if isinstance(layer, nn.Linear):
+                nn.init.kaiming_uniform_(layer.weight, mode="fan_in", nonlinearity="relu")
+                nn.init.constant_(layer.bias, 0.0)
+            elif isinstance(layer, nn.BatchNorm1d):
+                nn.init.constant_(layer.weight, 1.0)
+                nn.init.constant_(layer.bias, 0.0)
+
     def forward(self, x_ng: torch.Tensor) -> dict[str, torch.Tensor]:
+        """
+        Args:
+            x_ng:
+                Gene counts matrix.
+        Returns:
+            A dictionary with the loss value.
+        """
         # compute deep embeddings
         z = F.normalize(self.layers(x_ng))
 
-        # pdb.set_trace()
-
         # split input into augmented halves
         z1, z2 = torch.chunk(z, 2)
 
         # SimCLR loss
         loss = self.Xent_loss(z1, z2)
-        return {'loss': loss}
+        return {"loss": loss}
 
     def predict(self, x_ng: torch.Tensor, **kwargs: Any):
+        """
+        Send (transformed) data through the model and return outputs.
+
+        Args:
+            x_ng:
+                Gene counts matrix.
+        Returns:
+            A dictionary with the embedding matrix.
+        """
         with torch.no_grad():
+            x_ng = torch.chunk(x_ng, 2)[0]
             z = F.normalize(self.layers(x_ng))
-        return torch.chunk(z, 2)[0]
-
-    def reset_parameters(self) -> None:
-        for layer in self.layers:
-            if isinstance(layer, nn.Linear):
-                nn.init.kaiming_uniform_(layer.weight, mode='fan_in', nonlinearity='relu')
-                nn.init.constant_(layer.bias, 0.0)
-            elif isinstance(layer, nn.BatchNorm1d):
-                nn.init.constant_(layer.weight, 1.0)
-                nn.init.constant_(layer.bias, 0.0)
+        return {"x_ng": z}
diff --git a/cellarium/ml/models/nt_xent.py b/cellarium/ml/models/nt_xent.py
index 2638a069..f1b4d4a2 100644
--- a/cellarium/ml/models/nt_xent.py
+++ b/cellarium/ml/models/nt_xent.py
@@ -7,18 +7,17 @@
 from cellarium.ml.distributed.gather import GatherLayer
 from cellarium.ml.utilities.data import get_rank_and_num_replicas
 
-import pdb
-
-import logging
-
-# logging.basicConfig(level=logging.DEBUG)
-# logger = logging.getLogger()
-
 
 class NT_Xent(nn.Module):
     """
     Normalized Temperature-scaled cross-entropy loss.
 
+    **References:**
+
+    1. `A simple framework for contrastive learning of visual representations
+       (Chen, T., Kornblith, S., Norouzi, M., & Hinton, G.)
+       <https://arxiv.org/abs/2002.05709>`_.
+
     Args:
         batch_size:
             Expected batch size per distributed process.
@@ -50,11 +49,10 @@ def _slice_negative_mask(self, size: int, rank: int) -> torch.Tensor:
             rank:
                 The rank of the specified device.
         """
-
         negative_mask_full = ~torch.eye(size, dtype=bool).repeat((1, 2))
         mask = torch.chunk(negative_mask_full, self.world_size, dim=0)[rank]
         return mask
-
+
     @staticmethod
     def _similarity_fn(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
         """
@@ -66,22 +64,19 @@ def _similarity_fn(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
     def forward(self, z_i: torch.Tensor, z_j: torch.Tensor) -> torch.Tensor:
         """
         Gathers all inputs, then computes NT-Xent loss averaged over all
-        2n augmented samples. Each sample's corresponding pair is used as
-        its positive class, while the remaining (2n - 2) samples are its
-        negative classes.
+        2n augmented samples.
         """
-
-        # gather embeddings from distributed processing
+        # gather embeddings from distributed forward pass
         if self.world_size > 1:
             z_i_full = torch.cat(GatherLayer.apply(z_i), dim=0)
             z_j_full = torch.cat(GatherLayer.apply(z_j), dim=0)
         else:
             z_i_full = z_i
             z_j_full = z_j
-
-        # pdb.set_trace()
-
-        assert len(z_i_full) % self.world_size == 0, f'Expected batch to evenly divide across devices (set drop_last to True).'
+
+        assert (
+            len(z_i_full) % self.world_size == 0
+        ), "Expected batch to evenly divide across devices (set drop_last to True)."
 
         batch_size = len(z_i_full) // self.world_size
         rank, _ = get_rank_and_num_replicas()
@@ -89,6 +84,7 @@ def forward(self, z_i: torch.Tensor, z_j: torch.Tensor) -> torch.Tensor:
 
         z_both_full = torch.cat((z_i_full, z_j_full), dim=0)
 
+        # normalized similarity logits between device minibatch and full batch embeddings
         sim_i = NT_Xent._similarity_fn(z_i, z_both_full) / self.temperature
         sim_j = NT_Xent._similarity_fn(z_j, z_both_full) / self.temperature
 
@@ -96,9 +92,9 @@ def forward(self, z_i: torch.Tensor, z_j: torch.Tensor) -> torch.Tensor:
         pos_i = torch.diag(sim_i, rank * batch_size)
         pos_j = torch.diag(sim_j, rank * batch_size)
         positive_samples = torch.cat((pos_i, pos_j))
-        negative_samples = torch.cat([
-            sim_i[negative_mask].reshape(batch_size, -1),
-            sim_j[negative_mask].reshape(batch_size, -1)])
+        negative_samples = torch.cat(
+            [sim_i[negative_mask].reshape(batch_size, -1), sim_j[negative_mask].reshape(batch_size, -1)]
+        )
 
         labels = torch.zeros_like(positive_samples).long()
         logits = torch.cat((positive_samples.unsqueeze(1), negative_samples), dim=1)
diff --git a/cellarium/ml/transforms/binomial_resample.py b/cellarium/ml/transforms/binomial_resample.py
index 45951e68..ca30a346 100644
--- a/cellarium/ml/transforms/binomial_resample.py
+++ b/cellarium/ml/transforms/binomial_resample.py
@@ -45,6 +45,6 @@ def forward(self, x_ng: torch.Tensor) -> torch.Tensor:
         p_apply_n = Bernoulli(probs=self.p_apply).sample(x_ng.shape[:1]).type_as(x_ng).bool()
 
         x_aug = Binomial(total_count=x_ng, probs=p_binom_ng).sample()
-
+
         x_ng = torch.where(p_apply_n.unsqueeze(1), x_ng, x_aug)
-        return {'x_ng': x_ng}
+        return {"x_ng": x_ng}
diff --git a/cellarium/ml/transforms/dropout.py b/cellarium/ml/transforms/dropout.py
index 9c97c15f..3387c0f4 100644
--- a/cellarium/ml/transforms/dropout.py
+++ b/cellarium/ml/transforms/dropout.py
@@ -46,6 +46,6 @@ def forward(self, x_ng: torch.Tensor) -> torch.Tensor:
 
         x_aug = torch.clone(x_ng)
         x_aug[Bernoulli(probs=p_dropout_ng).sample().bool()] = 0
-
+
         x_ng = torch.where(p_apply_n.unsqueeze(1), x_ng, x_aug)
-        return {'x_ng': x_ng}
+        return {"x_ng": x_ng}
diff --git a/cellarium/ml/transforms/duplicate.py b/cellarium/ml/transforms/duplicate.py
index e010bb75..d75ebabc 100644
--- a/cellarium/ml/transforms/duplicate.py
+++ b/cellarium/ml/transforms/duplicate.py
@@ -23,4 +23,4 @@ def forward(self, x_ng: torch.Tensor) -> torch.Tensor:
         Returns:
             Duplicated counts.
         """
-        return {'x_ng': x_ng.repeat((2, 1))}
+        return {"x_ng": x_ng.repeat((2, 1))}
diff --git a/cellarium/ml/transforms/gaussian_noise.py b/cellarium/ml/transforms/gaussian_noise.py
index 46a5780a..8850a900 100644
--- a/cellarium/ml/transforms/gaussian_noise.py
+++ b/cellarium/ml/transforms/gaussian_noise.py
@@ -43,8 +43,8 @@ def forward(self, x_ng: torch.Tensor) -> torch.Tensor:
         """
         sigma_ng = Uniform(self.sigma_min, self.sigma_max).sample(x_ng.shape).type_as(x_ng)
         p_apply_n = Bernoulli(probs=self.p_apply).sample(x_ng.shape[:1]).type_as(x_ng).bool()
-
+
         x_aug = x_ng + Normal(0, sigma_ng).sample()
 
         x_ng = torch.where(p_apply_n.unsqueeze(1), x_ng, x_aug)
-        return {'x_ng': x_ng}
+        return {"x_ng": x_ng}
diff --git a/cellarium/ml/transforms/randomize.py b/cellarium/ml/transforms/randomize.py
index ad687dcc..3515c2e8 100644
--- a/cellarium/ml/transforms/randomize.py
+++ b/cellarium/ml/transforms/randomize.py
@@ -32,4 +32,4 @@ def forward(self, x_ng):
         Returns:
             Gene counts with randomly applied transform.
         """
-        return self.transform(x_ng) if torch.rand(1) < self.p_apply else {'x_ng': x_ng}
+        return self.transform(x_ng) if torch.rand(1) < self.p_apply else {"x_ng": x_ng}
diff --git a/cellarium/ml/utilities/core.py b/cellarium/ml/utilities/core.py
index 5e67464b..63c4f751 100644
--- a/cellarium/ml/utilities/core.py
+++ b/cellarium/ml/utilities/core.py
@@ -3,6 +3,7 @@
 
 import copy
 import math
+
 import torch
 
 from cellarium.ml.utilities.testing import assert_nonnegative, assert_positive
diff --git a/tests/test_contrastive_mlp.py b/tests/test_contrastive_mlp.py
deleted file mode 100644
index 8cb53485..00000000
--- a/tests/test_contrastive_mlp.py
+++ /dev/null
@@ -1,192 +0,0 @@
-# Copyright Contributors to the Cellarium project.
-# SPDX-License-Identifier: BSD-3-Clause
-
-import math
-import os
-from pathlib import Path
-
-import lightning.pytorch as pl
-import numpy as np
-import torch
-
-from cellarium.ml import CellariumAnnDataDataModule, CellariumModule
-from cellarium.ml.models import ContrastiveMLP
-from cellarium.ml.transforms import Duplicate
-from cellarium.ml.utilities.data import collate_fn
-from tests.common import BoringDataset
-
-from cellarium.ml.utilities.data import AnnDataField, get_rank_and_num_replicas
-
-import anndata
-
-import hashlib
-import logging
-import pickle
-
-logging.basicConfig(level=logging.DEBUG)
-logger = logging.getLogger()
-
-
-def test_multi_gpu_consistency(tmp_path: Path):
-    class BasicNet(torch.nn.Module):
-        def __init__(self):
-            super(BasicNet, self).__init__()
-            self.net = torch.nn.Linear(4, 2)
-
-        def forward(self, x_ng: torch.Tensor):
-            rank, num_replicas = get_rank_and_num_replicas()
-
-            logger.debug(f'in {rank}/{num_replicas}')
-            logger.debug(x_ng)
-            out = self.net(x_ng)
-            logger.debug(f'out {rank}/{num_replicas}')
-            logger.debug(out)
-            loss_n = torch.norm(out, dim=1)
-            logger.debug(f'loss_n {rank}/{num_replicas}')
-            logger.debug(loss_n)
-            return {'loss': loss_n.mean()}
-
-    class SimpleShift(torch.nn.Module):
-        def __init__(self, shift):
-            super().__init__()
-            self.shift = shift
-
-        def forward(self, x_ng: torch.Tensor):
-            shift = torch.zeros_like(x_ng)
-            shift[x_ng.shape[0] // 2:, :] = self.shift
-
-            return {'x_ng': x_ng + shift}
-
-    class SimpleGaussian(torch.nn.Module):
-        def __init__(self):
-            super().__init__()
-
-        def forward(self, x_ng: torch.Tensor):
-            shift = torch.randn_like(x_ng) / 4
-
-            return {'x_ng': x_ng + shift}
-
-    def get_cpu_state_dict(state_dict: dict) -> dict:
-        cpu_state_dict = dict()
-        for k, v in state_dict.items():
-            if isinstance(v, torch.Tensor):
-                cpu_state_dict[k] = v.detach().cpu().numpy()
-            elif isinstance(v, dict):
-                cpu_state_dict[k] = get_cpu_state_dict(v)
-            else:
-                cpu_state_dict[k] = v
-        return cpu_state_dict
-
-    def print_hash(
-            model,
-            optimizer,
-            scheduler):
-
-        logger.debug('optimizer full')
-        logger.debug(optimizer)
-
-        model_md5_hash = hashlib.md5(pickle.dumps(get_cpu_state_dict(module.state_dict()))).hexdigest()
-
-        logger.debug(f'model MD5 hash: {model_md5_hash}')
-
-        if optimizer is not None:
-            optimizer_md5_hash = hashlib.md5(pickle.dumps(get_cpu_state_dict(optimizer.state_dict()))).hexdigest()
-            logger.debug(f'optimizer MD5 hash: {optimizer_md5_hash}')
-        else:
-            logger.debug(f'optimizer: None')
-
-        if scheduler is not None:
-            scheduler_md5_hash = hashlib.md5(pickle.dumps(get_cpu_state_dict(scheduler.state_dict()))).hexdigest()
-            logger.debug(f'scheduler MD5 hash: {scheduler_md5_hash}')
-        else:
-            logger.debug(f'scheduler: None')
-
-    n_vars = 4
-    hidden_size = [2]
-    embed_dim = 2
-    data_size = 4
-    full_batch_size = 4
-    max_steps = 10
-
-    # dataset
-    rng = np.random.default_rng(123)
-    counts = rng.integers(1, 5, size=(data_size, n_vars)).astype("float32")
-    np.save('/home/jupyter/bw-bican-data/toy_ds.npy', counts)
-    adata = anndata.AnnData(counts)
-    adata.write('/home/jupyter/bw-bican-data/toy_ds.h5ad')
-
-    # logger.debug(counts)
-
-    # re-run test after switching to 1 or 2
-    n_gpu = 2
-
-    init_args = {
-        "g_genes": n_vars,
-        "hidden_size": hidden_size,
-        "embed_dim": embed_dim,
-        "batch_size": full_batch_size // n_gpu,
-        "world_size": n_gpu,
-    }
-    seed = 5
-    pl.seed_everything(seed)
-    # model = ContrastiveMLP(**init_args)  # type: ignore[arg-type]
-    model = BasicNet()
-    config = {
-        "model": {
-            "model": {
-                "class_path": "cellarium.ml.models.ContrastiveMLP",
-                "init_args": init_args,
-            },
-            'optim_fn': torch.optim.Adam,
-            'optim_kwargs': {
-                'lr': 0.002,
-            },
-            "transforms": [
-                {
-                    "class_path": "cellarium.ml.transforms.Duplicate"
-                },
-            ],
-        }
-    }
-    path = f'/home/jupyter/bw-bican-data/gpu_experiments/ds-{data_size}__bs-{full_batch_size}__seed-{seed}__gpu-{n_gpu}'
-
-    module = CellariumModule(transforms=[Duplicate()], model=model, optim_fn=config['model']['optim_fn'], optim_kwargs=config['model']['optim_kwargs'], config=config)
-    # module = CellariumModule(transforms=[Duplicate(), SimpleShift(0.5)], model=model, optim_fn=config['model']['optim_fn'], optim_kwargs=config['model']['optim_kwargs'], config=config)
-    # module = CellariumModule(model=model, optim_fn=config['model']['optim_fn'], optim_kwargs=config['model']['optim_kwargs'])
-
-    trainer = pl.Trainer(
-        strategy='ddp',
-        accelerator="cpu",
-        devices=n_gpu,
-        max_steps=max_steps,
-        default_root_dir=path,
-    )
-    data_module = CellariumAnnDataDataModule(
-        filenames='/home/jupyter/bw-bican-data/toy_ds.h5ad',
-        shard_size=data_size,
-        last_shard_size=data_size,
-        max_cache_size=1,
-        num_workers=1,
-        batch_keys={'x_ng': AnnDataField(attr='X')},
-        batch_size=full_batch_size // n_gpu,
-        shuffle=False,
-        drop_last=False)
-
-    # for p in module.model.parameters():
-    #     logger.debug(p.data)
-
-    # print_hash(model, None, None)
-
-    # fit
-    # trainer.fit(module, train_dataloaders=train_loader)
-    trainer.fit(module, datamodule=data_module)
-
-    # run tests only for rank 0
-    if trainer.global_rank != 0:
-        return
-
-# logger.debug('----------------AFTER TRAINING----------------')
-
-    # for p in module.model.parameters():
-    #     logger.debug(p.data)
-
-    assert False