Implement contrastive learning model and transforms #195

Draft · wants to merge 8 commits into base: main

16 changes: 14 additions & 2 deletions cellarium/ml/callbacks/prediction_writer.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import os
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from pathlib import Path

import lightning.pytorch as pl
@@ -16,6 +16,7 @@ def write_prediction(
ids: np.ndarray,
output_dir: Path | str,
postfix: int | str,
fields: Mapping[str, np.ndarray | torch.Tensor] | None = None,
) -> None:
"""
Write prediction to a CSV file.
@@ -33,6 +34,9 @@
if not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
df = pd.DataFrame(prediction.cpu())
if fields is not None:
for field_name, field_data in fields.items():
df.insert(0, field_name, field_data)
df.insert(0, "db_ids", ids)
output_path = os.path.join(output_dir, f"batch_{postfix}.csv")
df.to_csv(output_path, header=False, index=False)
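
A minimal usage sketch of the extended function (the field name and values below are hypothetical); each entry in fields becomes an extra per-cell column placed between db_ids and the prediction columns:

    import numpy as np
    import torch

    from cellarium.ml.callbacks.prediction_writer import write_prediction

    # Hypothetical toy inputs: 4 cells with 3-dimensional predictions.
    embedding = torch.randn(4, 3)
    ids = np.array(["cell-0", "cell-1", "cell-2", "cell-3"])

    write_prediction(
        prediction=embedding,
        ids=ids,
        output_dir="predictions",
        postfix=0,
        # "total_mrna_umis_n" is a hypothetical field name.
        fields={"total_mrna_umis_n": np.array([1200, 980, 1500, 760])},
    )
    # Writes predictions/batch_0.csv (no header row) with column order:
    # db_ids, total_mrna_umis_n, then the prediction columns.
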
@@ -56,10 +60,13 @@ class PredictionWriter(pl.callbacks.BasePredictionWriter):
written. If not ``None``, only the first ``prediction_size`` columns will be written.
"""

def __init__(self, output_dir: Path | str, prediction_size: int | None = None) -> None:
def __init__(
self, output_dir: Path | str, prediction_size: int | None = None, field_names: Sequence[str] | None = None
) -> None:
super().__init__(write_interval="batch")
self.output_dir = output_dir
self.prediction_size = prediction_size
self.field_names = field_names

def write_on_batch_end(
self,
@@ -76,9 +83,14 @@ def write_on_batch_end(
x_ng = x_ng[:, : self.prediction_size]

assert isinstance(batch["obs_names_n"], np.ndarray)
if self.field_names is None:
fields = None
else:
fields = {field_name: batch[field_name] for field_name in self.field_names}
write_prediction(
prediction=x_ng,
ids=batch["obs_names_n"],
output_dir=self.output_dir,
postfix=batch_idx * trainer.world_size + trainer.global_rank,
fields=fields,
)
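
A sketch of wiring the extended callback into a Trainer; field_names entries must match keys present in each prediction batch, and the name below is hypothetical:

    import lightning.pytorch as pl

    from cellarium.ml.callbacks.prediction_writer import PredictionWriter

    writer = PredictionWriter(
        output_dir="predictions",
        prediction_size=256,
        field_names=["total_mrna_umis_n"],  # hypothetical batch key
    )
    trainer = pl.Trainer(accelerator="cpu", devices=1, callbacks=[writer])
    # trainer.predict(module, datamodule=data)  # module and data defined elsewhere
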
62 changes: 59 additions & 3 deletions cellarium/ml/cli.py
@@ -63,9 +63,10 @@ class FileLoader:
file_path: str
loader_fn: Callable[[str], Any] | str
attr: str | None = None
key: str | None = None
convert_fn: Callable[[Any], Any] | str | None = None

def __new__(cls, file_path, loader_fn, attr, convert_fn):
def __new__(cls, file_path, loader_fn, attr=None, key=None, convert_fn=None):
if isinstance(loader_fn, str):
loader_fn = import_object(loader_fn)
if loader_fn not in cached_loaders:
@@ -75,6 +76,8 @@ def __new__(cls, file_path, loader_fn, attr=None, key=None, convert_fn=None):

if attr is not None:
obj = attrgetter(attr)(obj)
if key is not None:
obj = obj[key]

if isinstance(convert_fn, str):
convert_fn = import_object(convert_fn)
@@ -115,10 +118,11 @@ class CheckpointLoader(FileLoader):

file_path: str
attr: str | None = None
key: str | None = None
convert_fn: Callable[[Any], Any] | str | None = None

def __new__(cls, file_path, attr, convert_fn):
return super().__new__(cls, file_path, CellariumModule.load_from_checkpoint, attr, convert_fn)
def __new__(cls, file_path, attr=None, key=None, convert_fn=None):
return super().__new__(cls, file_path, CellariumModule.load_from_checkpoint, attr, key, convert_fn)
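
A sketch of the new key lookup (checkpoint path and key are hypothetical): attr is resolved first via attrgetter, then key indexes into the result when it is a mapping-like object:

    from cellarium.ml.cli import CheckpointLoader

    # FileLoader.__new__ returns the loaded value itself, so this expression
    # evaluates to the selected object rather than a loader instance.
    target_count = CheckpointLoader(
        file_path="runs/contrastive/checkpoints/last.ckpt",  # hypothetical path
        attr="hparams",      # LightningModule hyperparameters, an indexable mapping
        key="target_count",  # hypothetical key stored in the hyperparameters
    )
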


def file_loader_constructor(loader: yaml.SafeLoader, node: yaml.nodes.MappingNode) -> FileLoader:
@@ -182,6 +186,19 @@ def compute_n_obs(data: CellariumAnnDataDataModule) -> int:
return data.dadc.n_obs


def compute_n_vars(data: CellariumAnnDataDataModule) -> int:
"""
Compute the number of variables in the data.

Args:
data: A :class:`CellariumAnnDataDataModule` instance.

Returns:
The number of variables in the data.
"""
return data.dadc.n_vars


def compute_y_categories(data: CellariumAnnDataDataModule) -> np.ndarray:
"""
Compute the categories in the target variable.
@@ -596,6 +613,45 @@ def tdigest(args: ArgsType = None) -> None:
cli(args=args)


@register_model
def contrastive_mlp(args: ArgsType = None) -> None:
r"""
CLI to run the :class:`cellarium.ml.models.ContrastiveMLP` model.

This example shows how to perform contrastive learning with a default augmentation
strategy for omics data.

Example run::

cellarium-ml contrastive_mlp fit \
--model.model.init_args.hidden_size 4096 2048 1024 512 \
--model.model.init_args.embed_dim 256 \
--model.model.init_args.temperature 1.0 \
--model.model.init_args.target_count 10000 \
--data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
--data.shard_size 100 \
--data.max_cache_size 2 \
--data.batch_size 100 \
--data.num_workers 4 \
--trainer.accelerator gpu \
--trainer.devices 1 \
--trainer.default_root_dir runs/contrastive

Args:
args: Arguments to parse. If ``None`` the arguments are taken from ``sys.argv``.
"""
cli = lightning_cli_factory(
"cellarium.ml.models.ContrastiveMLP",
link_arguments=[
LinkArguments("data", "model.model.init_args.n_obs", compute_n_vars),
],
trainer_defaults={
"max_epochs": 20,
},
)
cli(args=args)


def main(args: ArgsType = None) -> None:
"""
CLI that dispatches to the appropriate model cli based on the model name in ``args`` and runs it.
@@ -1,8 +1,7 @@
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from cellarium.ml.distributed.gather import GatherLayer
from cellarium.ml.losses.nt_xent import NT_Xent

__all__ = [
"GatherLayer",
"NT_Xent",
]
104 changes: 104 additions & 0 deletions cellarium/ml/losses/nt_xent.py
@@ -0,0 +1,104 @@
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

import torch
from torch import nn

from cellarium.ml.utilities.distributed import GatherLayer, get_rank_and_num_replicas


class NT_Xent(nn.Module):
"""
Normalized Temperature-scaled cross-entropy loss.

**References:**

1. `A simple framework for contrastive learning of visual representations
(Chen, T., Kornblith, S., Norouzi, M., & Hinton, G.)
<https://arxiv.org/abs/2002.05709>`_.

Args:
temperature:
Logit scaling coefficient. Cosine similarities are divided by this
value, so a higher temperature yields smaller-magnitude logits and a
softer softmax over the contrastive candidates.
"""

def __init__(
self,
temperature: float = 1.0,
):
super().__init__()

self.temperature = temperature
self.criterion = nn.CrossEntropyLoss(reduction="mean")

def _slice_negative_mask(self, size: int, rank: int) -> torch.Tensor:
"""
Returns row slice of full negative mask corresponding to the segment
of the full batch held by the specified device.

Args:
size:
Number of embeddings per view across the full batch (the number
of rows of the full negative mask).
rank:
The rank of the specified device.
"""
_, world_size = get_rank_and_num_replicas()

negative_mask_full = ~torch.eye(size).bool().repeat((1, 2))
mask = torch.chunk(negative_mask_full, world_size, dim=0)[rank]
return mask

@staticmethod
def _similarity_fn(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
"""
Computes cosine similarity between normalized vectors,
which is equivalent to a standard inner product.
"""
return torch.einsum("nc,mc->nm", z1, z2)

def forward(self, z_i: torch.Tensor, z_j: torch.Tensor) -> torch.Tensor:
"""
Gathers all inputs, then computes NT-Xent loss averaged over all
2n augmented samples.
"""
_, world_size = get_rank_and_num_replicas()

# gather embeddings from distributed forward pass
if world_size > 1:
z_i_full = torch.cat(GatherLayer.apply(z_i), dim=0)
z_j_full = torch.cat(GatherLayer.apply(z_j), dim=0)
else:
z_i_full = z_i
z_j_full = z_j

assert (
len(z_i_full) % world_size == 0
), "Expected batch to evenly divide across devices (set drop_last to True)."

batch_size = len(z_i_full) // world_size
rank, _ = get_rank_and_num_replicas()
negative_mask = self._slice_negative_mask(len(z_i_full), rank)

z_both_full = torch.cat((z_i_full, z_j_full), dim=0)

# normalized similarity logits between device minibatch and full batch embeddings
sim_i = NT_Xent._similarity_fn(z_i, z_both_full) / self.temperature
sim_j = NT_Xent._similarity_fn(z_j, z_both_full) / self.temperature

pos_i = torch.diag(sim_i, (world_size + rank) * batch_size)
pos_j = torch.diag(sim_j, rank * batch_size)

positive_samples = torch.cat((pos_i, pos_j))
negative_samples = torch.cat(
[sim_i[negative_mask].reshape(batch_size, -1), sim_j[negative_mask].reshape(batch_size, -1)]
)

labels = torch.zeros_like(positive_samples).long()
logits = torch.cat((positive_samples.unsqueeze(1), negative_samples), dim=1)
loss = self.criterion(logits, labels)

return loss
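
For a positive pair (i, j), the per-sample loss is -log(exp(sim(z_i, z_j) / temperature) / sum over k != i of exp(sim(z_i, z_k) / temperature)), averaged over all 2n augmented samples. A single-process usage sketch with random embeddings (the sizes are arbitrary):

    import torch
    import torch.nn.functional as F

    from cellarium.ml.losses.nt_xent import NT_Xent

    loss_fn = NT_Xent(temperature=0.5)

    # Two augmented views of the same 8 cells. The similarity function assumes
    # unit-normalized embeddings, so normalize before calling.
    z_i = F.normalize(torch.randn(8, 128), dim=1)
    z_j = F.normalize(torch.randn(8, 128), dim=1)

    loss = loss_fn(z_i, z_j)  # scalar loss over the 16 augmented samples
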
2 changes: 2 additions & 0 deletions cellarium/ml/models/__init__.py
@@ -1,6 +1,7 @@
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from cellarium.ml.models.contrastive_mlp import ContrastiveMLP
from cellarium.ml.models.geneformer import Geneformer
from cellarium.ml.models.incremental_pca import IncrementalPCA
from cellarium.ml.models.logistic_regression import LogisticRegression
@@ -12,6 +13,7 @@

__all__ = [
"CellariumModel",
"ContrastiveMLP",
"Geneformer",
"IncrementalPCA",
"LogisticRegression",
91 changes: 91 additions & 0 deletions cellarium/ml/models/contrastive_mlp.py
@@ -0,0 +1,91 @@
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Sequence

import torch
import torch.nn.functional as F
from torch import nn

from cellarium.ml.losses.nt_xent import NT_Xent
from cellarium.ml.models.model import CellariumModel, PredictMixin


class ContrastiveMLP(CellariumModel, PredictMixin):
"""
Multilayer perceptron trained with contrastive learning.

Args:
n_obs:
Number of features in each input observation (the network input size;
the CLI links this to the number of variables in the data).
hidden_size:
Dimensionality of the fully-connected hidden layers.
embed_dim:
Size of embedding (network output size).
temperature:
Parameter governing Normalized Temperature-scaled cross-entropy (NT-Xent) loss.
"""

def __init__(
self,
n_obs: int,
hidden_size: Sequence[int],
embed_dim: int,
temperature: float = 1.0,
):
super().__init__()

self.layers = nn.Sequential()
self.layers.append(nn.Linear(n_obs, hidden_size[0]))
self.layers.append(nn.BatchNorm1d(hidden_size[0]))
self.layers.append(nn.ReLU())
for size_i, size_j in zip(hidden_size[:-1], hidden_size[1:]):
self.layers.append(nn.Linear(size_i, size_j))
self.layers.append(nn.BatchNorm1d(size_j))
self.layers.append(nn.ReLU())
self.layers.append(nn.Linear(hidden_size[-1], embed_dim))

self.Xent_loss = NT_Xent(temperature)

self.reset_parameters()

def reset_parameters(self) -> None:
for layer in self.layers:
if isinstance(layer, nn.Linear):
nn.init.kaiming_uniform_(layer.weight, mode="fan_in", nonlinearity="relu")
nn.init.constant_(layer.bias, 0.0)
elif isinstance(layer, nn.BatchNorm1d):
nn.init.constant_(layer.weight, 1.0)
nn.init.constant_(layer.bias, 0.0)

def forward(self, x_ng: torch.Tensor) -> dict[str, torch.Tensor]:
"""
Args:
x_ng:
Gene counts matrix holding two augmented views of the minibatch,
stacked along the batch dimension.
Returns:
A dictionary with the loss value.
"""
# compute deep embeddings
z = F.normalize(self.layers(x_ng))

# split input into augmented halves
z1, z2 = torch.chunk(z, 2)

# SimCLR loss
loss = self.Xent_loss(z1, z2)
return {"loss": loss}

def predict(self, x_ng: torch.Tensor) -> dict[str, torch.Tensor]:
"""
Sends (transformed) data through the model and returns outputs.

Args:
x_ng:
Gene counts matrix.
Returns:
A dictionary with the embedding matrix.
"""
with torch.no_grad():
z = F.normalize(self.layers(x_ng))
return {"x_ng": z}
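
A shape-level sketch of the training forward pass (the sizes are hypothetical). forward expects the two augmented views of a minibatch concatenated along the batch dimension, matching the torch.chunk split above:

    import torch

    from cellarium.ml.models import ContrastiveMLP

    model = ContrastiveMLP(
        n_obs=512,                       # input size; the CLI links this to the data's n_vars
        hidden_size=[4096, 2048, 1024],  # hypothetical hidden layer widths
        embed_dim=256,
        temperature=1.0,
    )

    # Two augmented views of the same 100 cells, stacked along dim 0.
    view_1 = torch.rand(100, 512)
    view_2 = torch.rand(100, 512)
    output = model(torch.cat([view_1, view_2], dim=0))
    print(output["loss"])
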