Add documentation and formatting to contrastive model files
bricewang committed May 31, 2024
1 parent de668c3 commit 0226328
Showing 10 changed files with 70 additions and 238 deletions.
1 change: 0 additions & 1 deletion cellarium/ml/core/pipeline.py
@@ -57,7 +57,6 @@ def predict(batch: dict[str, np.ndarray | torch.Tensor]) -> dict[str, np.n
raise TypeError(f"The last module in the pipeline must be an instance of {PredictMixin}. Got {model}")

for module in self[:-1]:
print(type(module))
# get the module input keys
ann = module.forward.__annotations__
input_keys = {key for key in ann if key != "return" and key in batch}
60 changes: 44 additions & 16 deletions cellarium/ml/models/contrastive_mlp.py
@@ -11,10 +11,24 @@
from cellarium.ml.models.model import CellariumModel, PredictMixin
from cellarium.ml.models.nt_xent import NT_Xent

import pdb


class ContrastiveMLP(CellariumModel, PredictMixin):
"""
Multilayer perceptron trained with contrastive learning.
Args:
g_genes:
Number of genes in each entry (network input size).
hidden_size:
Dimensionality of the fully-connected hidden layers.
embed_dim:
Size of embedding (network output size).
world_size:
Number of devices used in training.
temperature:
Parameter governing Normalized Temperature-scaled cross-entropy (NT-Xent) loss.
"""

def __init__(
self,
g_genes: int,
@@ -43,30 +57,44 @@ def __init__(

self.reset_parameters()


def reset_parameters(self) -> None:
for layer in self.layers:
if isinstance(layer, nn.Linear):
nn.init.kaiming_uniform_(layer.weight, mode="fan_in", nonlinearity="relu")
nn.init.constant_(layer.bias, 0.0)
elif isinstance(layer, nn.BatchNorm1d):
nn.init.constant_(layer.weight, 1.0)
nn.init.constant_(layer.bias, 0.0)

def forward(self, x_ng: torch.Tensor) -> dict[str, torch.Tensor]:
"""
Args:
x_ng:
Gene counts matrix.
Returns:
A dictionary with the loss value.
"""
# compute deep embeddings
z = F.normalize(self.layers(x_ng))

# pdb.set_trace()

# split input into augmented halves
z1, z2 = torch.chunk(z, 2)

# SimCLR loss
loss = self.Xent_loss(z1, z2)
return {'loss': loss}
return {"loss": loss}

def predict(self, x_ng: torch.Tensor, **kwargs: Any):
"""
Send (transformed) data through the model and return outputs.
Args:
x_ng:
Gene counts matrix.
Returns:
A dictionary with the embedding matrix.
"""
with torch.no_grad():
x_ng = torch.chunk(x_ng, 2)[0]
z = F.normalize(self.layers(x_ng))
return torch.chunk(z, 2)[0]

def reset_parameters(self) -> None:
for layer in self.layers:
if isinstance(layer, nn.Linear):
nn.init.kaiming_uniform_(layer.weight, mode='fan_in', nonlinearity='relu')
nn.init.constant_(layer.bias, 0.0)
elif isinstance(layer, nn.BatchNorm1d):
nn.init.constant_(layer.weight, 1.0)
nn.init.constant_(layer.bias, 0.0)
return {"x_ng": z}
38 changes: 17 additions & 21 deletions cellarium/ml/models/nt_xent.py
@@ -7,18 +7,17 @@
from cellarium.ml.distributed.gather import GatherLayer
from cellarium.ml.utilities.data import get_rank_and_num_replicas

import pdb

import logging

# logging.basicConfig(level=logging.DEBUG)
# logger = logging.getLogger()


class NT_Xent(nn.Module):
"""
Normalized Temperature-scaled cross-entropy loss.
**References:**
1. `A simple framework for contrastive learning of visual representations
(Chen, T., Kornblith, S., Norouzi, M., & Hinton, G.)
<https://arxiv.org/abs/2002.05709>`_.
Args:
batch_size:
Expected batch size per distributed process.
@@ -50,11 +49,10 @@ def _slice_negative_mask(self, size: int, rank: int) -> torch.Tensor:
rank:
The rank of the specified device.
"""

negative_mask_full = ~torch.eye(size, dtype=bool).repeat((1, 2))
mask = torch.chunk(negative_mask_full, self.world_size, dim=0)[rank]
return mask

@staticmethod
def _similarity_fn(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
"""
@@ -66,39 +64,37 @@ def _similarity_fn(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
def forward(self, z_i: torch.Tensor, z_j: torch.Tensor) -> torch.Tensor:
"""
Gathers all inputs, then computes NT-Xent loss averaged over all
2n augmented samples. Each sample's corresponding pair is used as
its positive class, while the remaining (2n - 2) samples are its
negative classes.
2n augmented samples.
"""

# gather embeddings from distributed processing
# gather embeddings from distributed forward pass
if self.world_size > 1:
z_i_full = torch.cat(GatherLayer.apply(z_i), dim=0)
z_j_full = torch.cat(GatherLayer.apply(z_j), dim=0)
else:
z_i_full = z_i
z_j_full = z_j
# pdb.set_trace()

assert len(z_i_full) % self.world_size == 0, f'Expected batch to evenly divide across devices (set drop_last to True).'

assert (
len(z_i_full) % self.world_size == 0
), "Expected batch to evenly divide across devices (set drop_last to True)."

batch_size = len(z_i_full) // self.world_size
rank, _ = get_rank_and_num_replicas()
negative_mask = self._slice_negative_mask(len(z_i_full), rank)

z_both_full = torch.cat((z_i_full, z_j_full), dim=0)

# normalized similarity logits between device minibatch and full batch embeddings
sim_i = NT_Xent._similarity_fn(z_i, z_both_full) / self.temperature
sim_j = NT_Xent._similarity_fn(z_j, z_both_full) / self.temperature

pos_i = torch.diag(sim_i, (self.world_size + rank) * batch_size)
pos_j = torch.diag(sim_j, rank * batch_size)

positive_samples = torch.cat((pos_i, pos_j))
negative_samples = torch.cat([
sim_i[negative_mask].reshape(batch_size, -1),
sim_j[negative_mask].reshape(batch_size, -1)])
negative_samples = torch.cat(
[sim_i[negative_mask].reshape(batch_size, -1), sim_j[negative_mask].reshape(batch_size, -1)]
)

labels = torch.zeros_like(positive_samples).long()
logits = torch.cat((positive_samples.unsqueeze(1), negative_samples), dim=1)
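
As a reference for the loss computed above: each embedding's positive is its augmented pair and every other embedding in the gathered batch is a negative, i.e. l(i, j) = -log( exp(sim(z_i, z_j)/tau) / sum_{k != i} exp(sim(z_i, z_k)/tau) ), averaged over all 2n samples. The sketch below is a single-process (world_size = 1) version for intuition only, not the distributed implementation in this file; the function name is hypothetical.

import torch
import torch.nn.functional as F

def nt_xent_single_process(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    # z1, z2: (n, d) embeddings of the two augmented views, assumed L2-normalized
    n = z1.size(0)
    z = torch.cat([z1, z2], dim=0)       # (2n, d)
    sim = z @ z.t() / temperature        # cosine similarities (rows are unit norm)
    sim.fill_diagonal_(float("-inf"))    # a sample is never its own negative
    # each view's positive class is its counterpart in the other half of the batch
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)])
    return F.cross_entropy(sim, targets)  # averaged over all 2n samples
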
4 changes: 2 additions & 2 deletions cellarium/ml/transforms/binomial_resample.py
@@ -45,6 +45,6 @@ def forward(self, x_ng: torch.Tensor) -> torch.Tensor:
p_apply_n = Bernoulli(probs=self.p_apply).sample(x_ng.shape[:1]).type_as(x_ng).bool()

x_aug = Binomial(total_count=x_ng, probs=p_binom_ng).sample()

x_ng = torch.where(p_apply_n.unsqueeze(1), x_ng, x_aug)
return {'x_ng': x_ng}
return {"x_ng": x_ng}
4 changes: 2 additions & 2 deletions cellarium/ml/transforms/dropout.py
@@ -46,6 +46,6 @@ def forward(self, x_ng: torch.Tensor) -> torch.Tensor:

x_aug = torch.clone(x_ng)
x_aug[Bernoulli(probs=p_dropout_ng).sample().bool()] = 0

x_ng = torch.where(p_apply_n.unsqueeze(1), x_ng, x_aug)
return {'x_ng': x_ng}
return {"x_ng": x_ng}
2 changes: 1 addition & 1 deletion cellarium/ml/transforms/duplicate.py
@@ -23,4 +23,4 @@ def forward(self, x_ng: torch.Tensor) -> torch.Tensor:
Returns:
Duplicated counts.
"""
return {'x_ng': x_ng.repeat((2, 1))}
return {"x_ng": x_ng.repeat((2, 1))}
4 changes: 2 additions & 2 deletions cellarium/ml/transforms/gaussian_noise.py
@@ -43,8 +43,8 @@ def forward(self, x_ng: torch.Tensor) -> torch.Tensor:
"""
sigma_ng = Uniform(self.sigma_min, self.sigma_max).sample(x_ng.shape).type_as(x_ng)
p_apply_n = Bernoulli(probs=self.p_apply).sample(x_ng.shape[:1]).type_as(x_ng).bool()

x_aug = x_ng + Normal(0, sigma_ng).sample()

x_ng = torch.where(p_apply_n.unsqueeze(1), x_ng, x_aug)
return {'x_ng': x_ng}
return {"x_ng": x_ng}
2 changes: 1 addition & 1 deletion cellarium/ml/transforms/randomize.py
@@ -32,4 +32,4 @@ def forward(self, x_ng):
Returns:
Gene counts with randomly applied transform.
"""
return self.transform(x_ng) if torch.rand(1) < self.p_apply else {'x_ng': x_ng}
return self.transform(x_ng) if torch.rand(1) < self.p_apply else {"x_ng": x_ng}
1 change: 1 addition & 0 deletions cellarium/ml/utilities/core.py
@@ -3,6 +3,7 @@

import copy
import math

import torch

from cellarium.ml.utilities.testing import assert_nonnegative, assert_positive