Cellarium gpt model #279

Open · wants to merge 20 commits into base: main
Changes from 11 commits
2 changes: 2 additions & 0 deletions cellarium/ml/data/__init__.py
@@ -8,6 +8,7 @@
LazyAnnData,
)
from cellarium.ml.data.fileio import read_h5ad_file, read_h5ad_gcs, read_h5ad_local, read_h5ad_url
from cellarium.ml.data.pytree_dataset import PyTreeDataset
from cellarium.ml.data.schema import AnnDataSchema

__all__ = [
@@ -16,6 +17,7 @@
"DistributedAnnDataCollectionView",
"IterableDistributedAnnDataCollectionDataset",
"LazyAnnData",
"PyTreeDataset",
"read_h5ad_file",
"read_h5ad_gcs",
"read_h5ad_local",
58 changes: 58 additions & 0 deletions cellarium/ml/data/pytree_dataset.py
@@ -0,0 +1,58 @@
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

import torch
from torch.utils._pytree import PyTree, tree_any, tree_iter, tree_map


class PyTreeDataset(torch.utils.data.Dataset):
"""
A dataset that wraps a PyTree of tensors and ndarrays.

Example::

import torch
from cellarium.ml.data import PyTreeDataset
from cellarium.ml.utilities.data import collate_fn

data = {
"gene_token_nc_dict": {
"gene_id": torch.randint(0, 10, (10, 3)),
"gene_value": torch.randint(0, 10, (10, 3)),
},
"gene_token_mask_nc": torch.randint(0, 10, (10, 3)),
"metadata_token_nc_dict": {
"cell_type": torch.randint(0, 10, (10, 3)),
},
"metadata_token_mask_nc_dict": {
"cell_type": torch.randint(0, 10, (10, 3)),
},
"prompt_mask_nc": torch.randint(0, 10, (10, 3)),
}
dataset = PyTreeDataset(data)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=2,
collate_fn=collate_fn,
)
for batch in dataloader:
...

Args:
pytree: A PyTree of tensors and ndarrays.
"""

def __init__(self, pytree: PyTree) -> None:
self._length: int = next(tree_iter(pytree)).shape[0] # type: ignore[call-overload]
if tree_any(lambda x: x.shape[0] != self._length, pytree):
raise ValueError("All tensors must have the same batch dimension")
self.pytree = pytree

def __getitem__(self, index: int) -> PyTree:
return tree_map(lambda data: data[index], self.pytree)

def __getitems__(self, indices: list[int]) -> list[PyTree]:
return [tree_map(lambda data: data[indices], self.pytree)]

def __len__(self) -> int:
return self._length
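Note that PyTreeDataset.__getitems__ indexes every leaf with the full list of indices, so the DataLoader's batched fetch path already receives the finished batch wrapped in a one-element list. Below is a minimal sketch of a pytree-aware collate function compatible with that behavior; the actual cellarium.ml.utilities.data.collate_fn is not part of this diff, and the name pytree_collate is hypothetical.

import torch
from torch.utils._pytree import PyTree, tree_map


def pytree_collate(batch: list[PyTree]) -> PyTree:
    # Batched fetch via __getitems__: the single element is already a batched pytree.
    if len(batch) == 1:
        return batch[0]
    # Per-item fetch via __getitem__: stack corresponding tensor leaves along a new batch dimension.
    return tree_map(lambda *leaves: torch.stack(leaves), *batch)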
47 changes: 26 additions & 21 deletions cellarium/ml/layers/embedding.py
@@ -16,8 +16,8 @@ class GeneExpressionEmbedding(nn.Module):
Args:
categorical_vocab_sizes:
Categorical gene token vocabulary sizes.
continuous_vocab_sizes:
Continuous gene token vocabulary sizes.
continuous_tokens:
Continuous gene tokens.
d_model:
Dimensionality of the embeddings and hidden states.
embeddings_initializer:
@@ -27,34 +27,39 @@
def __init__(
self,
categorical_vocab_sizes: dict[str, int],
continuous_vocab_sizes: dict[str, int],
continuous_tokens: list[str],
d_model: int,
embeddings_initializer: dict[str, Any],
) -> None:
super().__init__()
self.E = nn.ModuleDict()
self.E.update({key: nn.Embedding(vocab_size, d_model) for key, vocab_size in categorical_vocab_sizes.items()})
self.E.update(
{key: nn.Linear(vocab_size, d_model, bias=False) for key, vocab_size in continuous_vocab_sizes.items()}
self.embedding_dict = nn.ModuleDict()
self.embedding_dict.update(
{key: nn.Embedding(vocab_size, d_model) for key, vocab_size in categorical_vocab_sizes.items()}
)
self.embedding_dict.update({key: nn.Linear(1, d_model, bias=False) for key in continuous_tokens})
self.categorical_vocab_sizes = categorical_vocab_sizes
self.continuous_tokens = continuous_tokens
self.embeddings_initializer = embeddings_initializer

self._reset_parameters()

def _reset_parameters(self) -> None:
for module in self.E.children():
for module in self.embedding_dict.children():
create_initializer(self.embeddings_initializer)(module.weight)

def forward(self, gene_tokens_nc: dict[str, torch.Tensor]) -> torch.Tensor:
def forward(self, gene_tokens_nc_dict: dict[str, torch.Tensor]) -> torch.Tensor:
"""
Args:
gene_tokens_nc:
gene_tokens_nc_dict:
Dictionary of gene token tensors of shape ``(n, c)``.

Returns:
The gene embedding tensor of shape ``(n, c, d)``.
"""
return sum(self.E[key](gene_token_nc) for key, gene_token_nc in gene_tokens_nc.items())
return sum(
self.embedding_dict[key](gene_token_nc.unsqueeze(-1) if key in self.continuous_tokens else gene_token_nc)
for key, gene_token_nc in gene_tokens_nc_dict.items()
)


class MetadataEmbedding(nn.Module):
@@ -77,27 +82,27 @@ def __init__(
embeddings_initializer: dict[str, Any],
) -> None:
super().__init__()
self.E = nn.ModuleDict(
self.embedding_dict = nn.ModuleDict(
{key: nn.Embedding(vocab_size, d_model) for key, vocab_size in categorical_vocab_sizes.items()}
)
self.embeddings_initializer = embeddings_initializer

self._reset_parameters()

def _reset_parameters(self) -> None:
for module in self.E.children():
for module in self.embedding_dict.children():
create_initializer(self.embeddings_initializer)(module.weight)

def forward(self, metadata_tokens_n: dict[str, torch.Tensor]) -> torch.Tensor:
def forward(self, metadata_tokens_nc_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""
Args:
metadata_token_n:
Dictionary of metadata token tensors of shape ``(n,)``.
metadata_tokens_nc_dict:
Dictionary of metadata token tensors of shape ``(n, c)``.

Returns:
The metadata embedding tensor of shape ``(n, m, d)``.
Dictionary of metadata embedding tensors of shape ``(n, c, d)``.
"""
return torch.stack(
[self.E[key](metadata_token_n) for key, metadata_token_n in metadata_tokens_n.items()],
dim=1,
)
return {
key: self.embedding_dict[key](metadata_token_nc)
for key, metadata_token_nc in metadata_tokens_nc_dict.items()
}
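The continuous-token path above replaces nn.Linear(vocab_size, d_model) keyed by a vocabulary size with nn.Linear(1, d_model) applied to the raw value: forward() unsqueezes continuous tensors from (n, c) to (n, c, 1) so both paths produce (n, c, d_model) and can be summed, and MetadataEmbedding now returns a dict of per-key (n, c, d) embeddings instead of a stacked (n, m, d) tensor. A standalone shape sketch of the gene-embedding sum, with illustrative sizes rather than the repository's test code:

import torch
import torch.nn as nn

n, c, d_model, vocab_size = 4, 16, 8, 100
gene_id_nc = torch.randint(0, vocab_size, (n, c))      # categorical token
gene_value_nc = torch.rand(n, c)                       # continuous token

id_embedding = nn.Embedding(vocab_size, d_model)       # categorical path
value_projection = nn.Linear(1, d_model, bias=False)   # continuous path, as in this diff

# Continuous values are unsqueezed to (n, c, 1) before the linear map,
# so both terms have shape (n, c, d_model) and the sum is well defined.
embedding_ncd = id_embedding(gene_id_nc) + value_projection(gene_value_nc.unsqueeze(-1))
assert embedding_ncd.shape == (n, c, d_model)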
6 changes: 3 additions & 3 deletions cellarium/ml/layers/head.py
@@ -35,7 +35,7 @@ def __init__(
heads_initializer: dict[str, Any],
) -> None:
super().__init__()
self.W = nn.ModuleDict(
self.readout_dict = nn.ModuleDict(
{key: nn.Linear(d_model, vocab_size, use_bias) for key, vocab_size in categorical_vocab_sizes.items()}
)
self.output_logits_scale = output_logits_scale
@@ -44,7 +44,7 @@ def __init__(
self._reset_parameters()

def _reset_parameters(self) -> None:
for module in self.W.children():
for module in self.readout_dict.children():
create_initializer(self.heads_initializer)(module.weight)

if module.bias is not None:
@@ -59,4 +59,4 @@ def forward(self, hidden_state_ncd: torch.Tensor) -> dict[str, torch.Tensor]:
Returns:
Dictionary of output logits tensors of shape ``(n, c, vocab_size)``.
"""
return {key: self.output_logits_scale * self.W[key](hidden_state_ncd) for key in self.W}
return {key: self.output_logits_scale * self.readout_dict[key](hidden_state_ncd) for key in self.readout_dict}
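Each readout head maps the hidden state of shape (n, c, d_model) to logits of shape (n, c, vocab_size) and multiplies by output_logits_scale. A minimal sketch of that computation with stand-in keys and sizes (not the repository's test code):

import torch
import torch.nn as nn

n, c, d_model = 4, 16, 8
vocab_sizes = {"gene_value": 11, "cell_type": 30}   # hypothetical keys and sizes
output_logits_scale = 1.0

readout_dict = nn.ModuleDict({key: nn.Linear(d_model, size) for key, size in vocab_sizes.items()})
hidden_state_ncd = torch.randn(n, c, d_model)

# One scaled linear readout per key, as in the forward() shown above.
logits = {key: output_logits_scale * readout_dict[key](hidden_state_ncd) for key in readout_dict}
assert logits["cell_type"].shape == (n, c, vocab_sizes["cell_type"])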
2 changes: 2 additions & 0 deletions cellarium/ml/models/__init__.py
@@ -1,6 +1,7 @@
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from cellarium.ml.models.cellarium_gpt import CellariumGPT
from cellarium.ml.models.geneformer import Geneformer
from cellarium.ml.models.incremental_pca import IncrementalPCA
from cellarium.ml.models.logistic_regression import LogisticRegression
@@ -10,6 +11,7 @@
from cellarium.ml.models.tdigest import TDigest

__all__ = [
"CellariumGPT",
"CellariumModel",
"Geneformer",
"IncrementalPCA",