diff --git a/cellarium/ml/layers/__init__.py b/cellarium/ml/layers/__init__.py index 189275a4..da43b264 100644 --- a/cellarium/ml/layers/__init__.py +++ b/cellarium/ml/layers/__init__.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause from cellarium.ml.layers.attention import MultiHeadAttention -from cellarium.ml.layers.embedding import GeneExpressionEmbedding, MetadataEmbedding +from cellarium.ml.layers.embedding import TokenEmbedding from cellarium.ml.layers.ffn import PositionWiseFFN from cellarium.ml.layers.head import MultiHeadReadout from cellarium.ml.layers.mu_linear import MuLinear @@ -10,8 +10,7 @@ from cellarium.ml.layers.transformer import Transformer, TransformerBlock __all__ = [ - "GeneExpressionEmbedding", - "MetadataEmbedding", + "TokenEmbedding", "MuLinear", "MultiHeadAttention", "MultiHeadReadout", diff --git a/cellarium/ml/layers/embedding.py b/cellarium/ml/layers/embedding.py index 6db31786..453c1789 100644 --- a/cellarium/ml/layers/embedding.py +++ b/cellarium/ml/layers/embedding.py @@ -9,15 +9,15 @@ from cellarium.ml.utilities.layers import create_initializer -class GeneExpressionEmbedding(nn.Module): +class TokenEmbedding(nn.Module): """ - Gene embedding. + Gene and metadata tokens embedding. Args: categorical_vocab_sizes: - Categorical gene token vocabulary sizes. - continuous_vocab_sizes: - Continuous gene token vocabulary sizes. + Categorical token vocabulary sizes. + continuous_tokens: + Continuous tokens. d_model: Dimensionality of the embeddings and hidden states. embeddings_initializer: @@ -27,77 +27,43 @@ class GeneExpressionEmbedding(nn.Module): def __init__( self, categorical_vocab_sizes: dict[str, int], - continuous_vocab_sizes: dict[str, int], + continuous_tokens: list[str], d_model: int, embeddings_initializer: dict[str, Any], ) -> None: super().__init__() - self.E = nn.ModuleDict() - self.E.update({key: nn.Embedding(vocab_size, d_model) for key, vocab_size in categorical_vocab_sizes.items()}) - self.E.update( - {key: nn.Linear(vocab_size, d_model, bias=False) for key, vocab_size in continuous_vocab_sizes.items()} - ) - self.embeddings_initializer = embeddings_initializer - - self._reset_parameters() - - def _reset_parameters(self) -> None: - for module in self.E.children(): - create_initializer(self.embeddings_initializer)(module.weight) - - def forward(self, gene_tokens_nc: dict[str, torch.Tensor]) -> torch.Tensor: - """ - Args: - gene_tokens_nc: - Dictionary of gene token tensors of shape ``(n, c)``. - - Returns: - The gene embedding tensor of shape ``(n, c, d)``. - """ - return sum(self.E[key](gene_token_nc) for key, gene_token_nc in gene_tokens_nc.items()) - - -class MetadataEmbedding(nn.Module): - """ - Metadata embedding. - - Args: - categorical_vocab_sizes: - Categorical metadata token vocabulary sizes. - d_model: - Dimensionality of the embeddings and hidden states. - initializer: - Initializer for the embeddings. 
- """ - - def __init__( - self, - categorical_vocab_sizes: dict[str, int], - d_model: int, - embeddings_initializer: dict[str, Any], - ) -> None: - super().__init__() - self.E = nn.ModuleDict( + self.embedding_dict = nn.ModuleDict() + self.embedding_dict.update( {key: nn.Embedding(vocab_size, d_model) for key, vocab_size in categorical_vocab_sizes.items()} ) + self.embedding_dict.update({key: nn.Linear(1, d_model, bias=False) for key in continuous_tokens}) + self.categorical_vocab_sizes = categorical_vocab_sizes + self.continuous_tokens = continuous_tokens self.embeddings_initializer = embeddings_initializer self._reset_parameters() def _reset_parameters(self) -> None: - for module in self.E.children(): + for module in self.embedding_dict.children(): create_initializer(self.embeddings_initializer)(module.weight) - def forward(self, metadata_tokens_n: dict[str, torch.Tensor]) -> torch.Tensor: + def forward( + self, + token_value_nc_dict: dict[str, torch.Tensor], + token_mask_nc_dict: dict[str, torch.Tensor], + ) -> torch.Tensor: """ Args: - metadata_token_n: - Dictionary of metadata token tensors of shape ``(n,)``. + token_value_nc_dict: + Dictionary of token value tensors of shape ``(n, c)``. + token_mask_nc_dict: + Dictionary of token mask tensors of shape ``(n, c)``. Returns: - The metadata embedding tensor of shape ``(n, m, d)``. + Embedding tensor of shape ``(n, c, d)``. """ - return torch.stack( - [self.E[key](metadata_token_n) for key, metadata_token_n in metadata_tokens_n.items()], - dim=1, + return sum( + self.embedding_dict[key](token_value_nc.unsqueeze(-1) if key in self.continuous_tokens else token_value_nc) + * token_mask_nc_dict[key].unsqueeze(-1) + for i, (key, token_value_nc) in enumerate(token_value_nc_dict.items()) ) diff --git a/cellarium/ml/layers/head.py b/cellarium/ml/layers/head.py index 1d6b94e3..aad1d8e7 100644 --- a/cellarium/ml/layers/head.py +++ b/cellarium/ml/layers/head.py @@ -35,7 +35,7 @@ def __init__( heads_initializer: dict[str, Any], ) -> None: super().__init__() - self.W = nn.ModuleDict( + self.readout_dict = nn.ModuleDict( {key: nn.Linear(d_model, vocab_size, use_bias) for key, vocab_size in categorical_vocab_sizes.items()} ) self.output_logits_scale = output_logits_scale @@ -44,7 +44,7 @@ def __init__( self._reset_parameters() def _reset_parameters(self) -> None: - for module in self.W.children(): + for module in self.readout_dict.children(): create_initializer(self.heads_initializer)(module.weight) if module.bias is not None: @@ -59,4 +59,4 @@ def forward(self, hidden_state_ncd: torch.Tensor) -> dict[str, torch.Tensor]: Returns: Dictionary of output logits tensors of shape ``(n, c, vocab_size)``. """ - return {key: self.output_logits_scale * self.W[key](hidden_state_ncd) for key in self.W} + return {key: self.output_logits_scale * self.readout_dict[key](hidden_state_ncd) for key in self.readout_dict} diff --git a/cellarium/ml/models/__init__.py b/cellarium/ml/models/__init__.py index e5b05abf..d97dd110 100644 --- a/cellarium/ml/models/__init__.py +++ b/cellarium/ml/models/__init__.py @@ -1,6 +1,7 @@ # Copyright Contributors to the Cellarium project. 
# SPDX-License-Identifier: BSD-3-Clause +from cellarium.ml.models.cellarium_gpt import CellariumGPT from cellarium.ml.models.geneformer import Geneformer from cellarium.ml.models.incremental_pca import IncrementalPCA from cellarium.ml.models.logistic_regression import LogisticRegression @@ -10,6 +11,7 @@ from cellarium.ml.models.tdigest import TDigest __all__ = [ + "CellariumGPT", "CellariumModel", "Geneformer", "IncrementalPCA", diff --git a/cellarium/ml/models/cellarium_gpt.py b/cellarium/ml/models/cellarium_gpt.py new file mode 100644 index 00000000..b04a2bdd --- /dev/null +++ b/cellarium/ml/models/cellarium_gpt.py @@ -0,0 +1,361 @@ +# Copyright Contributors to the Cellarium project. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Literal + +import lightning.pytorch as pl +import numpy as np +import torch +from torch import nn +from torch.nn.attention.flex_attention import BlockMask, create_block_mask + +from cellarium.ml.layers import MultiHeadReadout, TokenEmbedding, Transformer +from cellarium.ml.models.model import CellariumModel, PredictMixin, ValidateMixin +from cellarium.ml.utilities.layers import scale_initializers_by_dimension +from cellarium.ml.utilities.mup import LRAdjustmentGroup + +try: + from cerebras.pytorch.backend import use_cs +except ImportError: + + def use_cs() -> bool: + return False + + +def prompt_diagonal_mask(prompt_mask_nc: torch.Tensor) -> torch.Tensor: + """ + Generate a prompt diagonal mask for self-attention. + + Args: + prompt_mask_nc: + The prompt mask. + + Returns: + torch.Tensor: The prompt diagonal mask. + + Example: + + For prompt_mask = [True, False, True, False, False], the attention mask is: + + [[True, False, True, False, False], + [True, True, True, False, False], + [True, False, True, False, False], + [True, False, True, True, False], + [True, False, True, False, True]] + """ + device = prompt_mask_nc.device + n, c = prompt_mask_nc.shape + if use_cs(): + c_range = torch.arange(c, device=device, dtype=torch.float32) + diag_mask_ncc = (c_range[:, None].expand(n, -1, 1) - c_range.expand(n, 1, -1)).abs() + prompt_mask_n1c = 1 - prompt_mask_nc[:, None, :].float() + attention_mask_ncc = diag_mask_ncc * prompt_mask_n1c + return attention_mask_ncc == 0 + else: + diag_mask_cc = torch.eye(c, dtype=torch.bool, device=device) + attention_mask_ncc = prompt_mask_nc[:, None, :] | diag_mask_cc + return attention_mask_ncc + + +class CellariumGPT(CellariumModel, PredictMixin, ValidateMixin): + """ + CellariumGPT model. + + Args: + gene_vocab_sizes: + Gene token vocabulary sizes. Must include "gene_value" and "gene_id". Additionally, it can include + experimental conditions, such as "assay" and "suspension_type". + metadata_vocab_sizes: + Metadata token vocabulary sizes. This can include metadata tokens such as "cell_type", "tissue", + "sex", "development_stage", and "disease". + d_model: + Dimensionality of the embeddings and hidden states. + d_ffn: + Dimensionality of the inner feed-forward layers. + n_heads: + Number of attention heads. + n_blocks: + Number of transformer blocks. + dropout_p: + Dropout probability. + use_bias: + Whether to use bias in the linear transformations. + attention_backend: + Backend for the attention computation. + attention_softmax_fp32: + Whether to use float32 for softmax computation when ``torch`` backend is used. + loss_scales: + A dictionary of loss scales for each label type. + initializer_range: + The standard deviation of the truncated normal initializer. + embeddings_scale: + Multiplier for the embeddings. 
+ attention_logits_scale: + Multiplier for the attention logits. + output_logits_scale: + Multiplier for the output logits. + mup_base_d_model: + Base dimensionality of the model for muP. + mup_base_d_ffn: + Base dimensionality of the inner feed-forward layers for muP. + """ + + def __init__( + self, + # Vocab sizes + gene_vocab_sizes: dict[str, int], + metadata_vocab_sizes: dict[str, int], + # Model parameters + d_model: int, + d_ffn: int, + n_heads: int, + n_blocks: int, + dropout_p: float, + use_bias: bool, + attention_backend: Literal["flex", "math", "mem_efficient", "torch"], + attention_softmax_fp32: bool, + loss_scales: dict[str, float], + # Tunable parameters + initializer_range: float = 0.02, + embeddings_scale: float = 1.0, + attention_logits_scale: float = 1.0, + output_logits_scale: float = 1.0, + # muP (maximal update parameterization) parameters + mup_base_d_model: int | None = None, + mup_base_d_ffn: int | None = None, + ) -> None: + super().__init__() + + # Vocab sizes + self.gene_vocab_sizes = gene_vocab_sizes + self.metadata_vocab_sizes = metadata_vocab_sizes + + # Initializers + self.initializer_range = initializer_range + default_initializer = { + "name": "trunc_normal_", + "mean": 0.0, + "std": self.initializer_range, + "a": -2 * self.initializer_range, + "b": 2 * self.initializer_range, + } + embeddings_initializer = default_initializer.copy() + Wqkv_initializer = default_initializer.copy() + Wo_initializer = default_initializer.copy() + dense1_initializer = default_initializer.copy() + dense2_initializer = default_initializer.copy() + heads_initializer = default_initializer.copy() + self.lr_adjustment_groups = { + "embedding": LRAdjustmentGroup("*embedding*weight"), + "decoder_attention": LRAdjustmentGroup("*transformer*attention*W*weight"), + "decoder_input_ffn": LRAdjustmentGroup("*transformer*ffn.dense1*weight"), + "decoder_output_ffn": LRAdjustmentGroup("*transformer*ffn.dense2*weight"), + } + + # Multipliers + self.embeddings_scale = embeddings_scale + self.attention_logits_scale = attention_logits_scale + self.output_logits_scale = output_logits_scale + + # Handle muP scaling for Adam and AdamW optimizers + if mup_base_d_model: + d_model_width_mult = d_model / mup_base_d_model + scale_initializers_by_dimension( + [Wqkv_initializer, dense1_initializer], + width_scale=d_model_width_mult**-0.5, + ) + scale_initializers_by_dimension( + Wo_initializer, + width_scale=d_model_width_mult**-0.5, + depth_scale=(2 * n_blocks) ** -0.5, + ) + self.output_logits_scale /= d_model_width_mult + for lr_adjustment_group in [ + "decoder_attention", + "decoder_input_ffn", + ]: + self.lr_adjustment_groups[lr_adjustment_group].set_scale(1 / d_model_width_mult) + self.width_mult = d_model_width_mult + else: + scale_initializers_by_dimension( + Wo_initializer, + depth_scale=(2 * n_blocks) ** -0.5, + ) + + if mup_base_d_ffn: + d_ffn_width_mult = d_ffn / mup_base_d_ffn + scale_initializers_by_dimension( + dense2_initializer, + width_scale=d_ffn_width_mult**-0.5, + depth_scale=(2 * n_blocks) ** -0.5, + ) + self.lr_adjustment_groups["decoder_output_ffn"].set_scale(1 / d_ffn_width_mult) + assert self.width_mult == d_ffn_width_mult + else: + scale_initializers_by_dimension( + dense2_initializer, + depth_scale=(2 * n_blocks) ** -0.5, + ) + + gene_categorical_vocab_sizes = gene_vocab_sizes.copy() + gene_value_vocab_size = gene_categorical_vocab_sizes.pop("gene_value") # used for the readout head + self.token_embedding = TokenEmbedding( + categorical_vocab_sizes=gene_categorical_vocab_sizes + # 
Add 1 to the vocab size for the metadata tokens to account for the mask token + | {key: vocab_size + 1 for key, vocab_size in metadata_vocab_sizes.items()}, + continuous_tokens=["gene_value", "gene_query_mask", "total_mrna_umis"], + d_model=d_model, + embeddings_initializer=embeddings_initializer, + ) + self.transformer = Transformer( + d_model=d_model, + d_ffn=d_ffn, + use_bias=use_bias, + n_heads=n_heads, + n_blocks=n_blocks, + dropout_p=dropout_p, + attention_logits_scale=attention_logits_scale, + attention_backend=attention_backend, + attention_softmax_fp32=attention_softmax_fp32, + Wqkv_initializer=Wqkv_initializer, + Wo_initializer=Wo_initializer, + dense1_initializer=dense1_initializer, + dense2_initializer=dense2_initializer, + ) + self.head = MultiHeadReadout( + categorical_vocab_sizes={"gene_value": gene_value_vocab_size, **metadata_vocab_sizes}, + d_model=d_model, + use_bias=use_bias, + output_logits_scale=output_logits_scale, + heads_initializer=heads_initializer, + ) + self.loss_scales = loss_scales + + self.reset_parameters() + + def reset_parameters(self) -> None: + def _reset_parameters(module): + return getattr(module, "_reset_parameters", lambda: None)() + + self.apply(_reset_parameters) + + @property + def d_model(self) -> int: + return self.transformer.blocks[0].d_model + + @property + def d_ffn(self) -> int: + return self.transformer.blocks[0].d_ffn + + @property + def n_heads(self) -> int: + return self.transformer.blocks[0].n_heads + + @property + def n_blocks(self) -> int: + return len(self.transformer.blocks) + + @property + def attention_backend(self) -> Literal["flex", "math", "mem_efficient", "torch"]: + return self.transformer.blocks[0].attention.attention_backend + + @attention_backend.setter + def attention_backend(self, value: Literal["flex", "math", "mem_efficient", "torch"]) -> None: + for block in self.transformer.blocks: + block.attention.attention_backend = value + + def predict( + self, + token_value_nc_dict: dict[str, torch.Tensor], + token_mask_nc_dict: dict[str, torch.Tensor], + prompt_mask_nc: torch.Tensor, + ) -> dict[str, np.ndarray | torch.Tensor]: + """ + Args: + token_value_nc_dict: + Dictionary of token value tensors of shape ``(n, c)``. + token_mask_nc_dict: + Dictionary of token mask tensors of shape ``(n, c)``. + + Returns: + Dictionary of logits tensors of shape ``(n, c, k)``. 
+ """ + # Create embeddings + embedding_ncd = self.token_embedding(token_value_nc_dict, token_mask_nc_dict) + + # Create attention mask + attention_mask_ncc: torch.Tensor | BlockMask + if self.attention_backend == "flex": + + def prompt_diagonal_mask_mod(b, h, q_idx, kv_idx): + return prompt_mask_nc[b, kv_idx] | (q_idx == kv_idx) + + n, c = prompt_mask_nc.shape + attention_mask_ncc = create_block_mask(prompt_diagonal_mask_mod, B=n, H=None, Q_LEN=c, KV_LEN=c) + else: + attention_mask_ncc = prompt_diagonal_mask(prompt_mask_nc) + + # Transformer blocks + hidden_state_ncd = embedding_ncd * self.embeddings_scale + hidden_state_ncd = self.transformer(hidden_state_ncd, attention_mask_ncc) + + # Compute logits + logits_nck_dict = self.head(hidden_state_ncd) + + return logits_nck_dict + + def forward( + self, + token_value_nc_dict: dict[str, torch.Tensor], + token_mask_nc_dict: dict[str, torch.Tensor], + prompt_mask_nc: torch.Tensor, + label_nc_dict: dict[str, torch.Tensor], + label_weight_nc_dict: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + logits_nck_dict = self.predict( + token_value_nc_dict=token_value_nc_dict, + token_mask_nc_dict=token_mask_nc_dict, + prompt_mask_nc=prompt_mask_nc, + ) + + # Compute loss + losses = {} + loss_fn = nn.CrossEntropyLoss(reduction="none") + # Make sure that label_nc_dict is created by concatenating the gene_value and metadata labels + # in the same order as the embeddings. + for key, label_nc in label_nc_dict.items(): + logits_nck = logits_nck_dict[key] + assert isinstance(logits_nck, torch.Tensor) + label_weight_nc = label_weight_nc_dict[key] + assert isinstance(label_weight_nc, torch.Tensor) + losses[key] = torch.sum( + loss_fn(logits_nck.view(label_nc.numel(), -1), label_nc.view(-1).long()) * label_weight_nc.view(-1) + ) + + loss = sum(losses[key] * self.loss_scales[key] for key in losses) + assert isinstance(loss, torch.Tensor) + losses["loss"] = loss + + return losses + + def validate( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + batch_idx: int, + token_value_nc_dict: dict[str, torch.Tensor], + token_mask_nc_dict: dict[str, torch.Tensor], + prompt_mask_nc: torch.Tensor, + label_nc_dict: dict[str, torch.Tensor], + label_weight_nc_dict: dict[str, torch.Tensor], + ) -> None: + n = prompt_mask_nc.shape[0] + loss_dict = self.forward( + token_value_nc_dict=token_value_nc_dict, + token_mask_nc_dict=token_mask_nc_dict, + prompt_mask_nc=prompt_mask_nc, + label_nc_dict=label_nc_dict, + label_weight_nc_dict=label_weight_nc_dict, + ) + + pl_module.log_dict(loss_dict, sync_dist=True, on_epoch=True, batch_size=n) diff --git a/cellarium/ml/utilities/data.py b/cellarium/ml/utilities/data.py index e0e37825..c6b5f59c 100644 --- a/cellarium/ml/utilities/data.py +++ b/cellarium/ml/utilities/data.py @@ -144,9 +144,9 @@ def categories_to_codes(x: pd.Series | pd.DataFrame) -> np.ndarray: Numpy array. 
""" if isinstance(x, pd.DataFrame): - return x.apply(lambda col: col.cat.codes).to_numpy() + return x.apply(lambda col: col.cat.codes).to_numpy(dtype=np.int32) else: - return np.asarray(x.cat.codes) + return np.asarray(x.cat.codes, dtype=np.int32) def get_categories(x: pd.Series) -> np.ndarray: diff --git a/cellarium/ml/utilities/mup.py b/cellarium/ml/utilities/mup.py index 368ad795..097b6e6c 100644 --- a/cellarium/ml/utilities/mup.py +++ b/cellarium/ml/utilities/mup.py @@ -14,6 +14,14 @@ def convert_glob_to_regex(f: str) -> re.Pattern: return re.compile(fnmatch.translate(f)) +class glob_expression_param_filter: + def __init__(self, param_filters: list[re.Pattern]) -> None: + self.param_filters = param_filters + + def __call__(self, name: str) -> bool: + return any(filter.fullmatch(name) for filter in self.param_filters) + + def make_param_filter(param_filter: str | list[str]) -> Callable[[str], bool]: """ Returns the corresponding filter for parameters for the given `param_filter`. @@ -34,10 +42,7 @@ def make_param_filter(param_filter: str | list[str]) -> Callable[[str], bool]: ) ) - def glob_expression_param_filter(name: str) -> bool: - return any(filter.fullmatch(name) for filter in param_filters) - - return glob_expression_param_filter + return glob_expression_param_filter(param_filters) class LRAdjustmentGroup: diff --git a/tests/test_cellarium_gpt.py b/tests/test_cellarium_gpt.py new file mode 100644 index 00000000..ac3d97ac --- /dev/null +++ b/tests/test_cellarium_gpt.py @@ -0,0 +1,104 @@ +# Copyright Contributors to the Cellarium project. +# SPDX-License-Identifier: BSD-3-Clause + +import os +from pathlib import Path + +import lightning.pytorch as pl +import numpy as np +import torch + +from cellarium.ml import CellariumModule +from cellarium.ml.data import PyTreeDataset, read_h5ad_file +from cellarium.ml.models import CellariumGPT +from cellarium.ml.utilities.data import categories_to_codes, collate_fn + + +def test_load_from_checkpoint_multi_device(tmp_path: Path): + adata = read_h5ad_file("https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_0.h5ad") + n = adata.n_obs + s = 4 # number of subsampled genes + c = 5 # context size + batch_size = 2 + devices = int(os.environ.get("TEST_DEVICES", "1")) + prompt_mask_nc = np.random.choice([True, False], size=(n, c), p=[0.5, 0.5]) + query_mask_nc = (~prompt_mask_nc).astype(np.float32) + X = adata.X[:, :s].toarray() + + cell_type = categories_to_codes(adata.obs["cell_type"])[:, None] + data = { + "token_value_nc_dict": { + "gene_id": np.broadcast_to(np.arange(c), (n, c)), + "gene_value": np.concatenate([X, np.zeros((n, 1), dtype=np.float32)], axis=1), + "gene_query_mask": query_mask_nc, + "total_mrna_umis": np.broadcast_to( + np.asarray(adata.obs["total_mrna_umis"], dtype=np.float32)[:, None], (n, c) + ), + "cell_type": np.broadcast_to(cell_type, (n, c)), + }, + "token_mask_nc_dict": { + "gene_id": (np.broadcast_to(np.arange(c), (n, c)) < s).astype(np.float32), + "gene_value": (np.broadcast_to(np.arange(c), (n, c)) < s).astype(np.float32), + "gene_query_mask": (np.broadcast_to(np.arange(c), (n, c)) < s).astype(np.float32), + "total_mrna_umis": (np.broadcast_to(np.arange(c), (n, c)) < s).astype(np.float32), + "cell_type": (np.broadcast_to(np.arange(c), (n, c)) == s).astype(np.float32), + }, + "prompt_mask_nc": prompt_mask_nc, + "label_nc_dict": { + "gene_value": np.concatenate([X, np.zeros((n, 1))], axis=1), + "cell_type": np.concatenate([np.zeros((n, s)), cell_type], axis=1), + }, + "label_weight_nc_dict": { + 
"gene_value": (np.broadcast_to(np.arange(c), (n, c)) < s).astype(np.float32), + "cell_type": (np.broadcast_to(np.arange(c), (n, c)) == s).astype(np.float32), + }, + } + dataset = PyTreeDataset(data) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + num_workers=0, + collate_fn=collate_fn, + ) + # model + model = CellariumGPT( + gene_vocab_sizes={"gene_id": c, "gene_value": 100}, + metadata_vocab_sizes={"cell_type": adata.obs["cell_type"].cat.categories.size}, + d_model=3, + d_ffn=6, + n_heads=1, + n_blocks=1, + dropout_p=0, + use_bias=False, + attention_backend="torch", + attention_softmax_fp32=True, + loss_scales={"gene_value": 0.8, "cell_type": 0.2}, + attention_logits_scale=1, + mup_base_d_model=2, + mup_base_d_ffn=4, + ) + module = CellariumModule(model=model, optim_fn=torch.optim.Adam, optim_kwargs={"lr": 1e-3, "eps": 1e-8}) + # trainer + trainer = pl.Trainer( + accelerator="cpu", + devices=devices, + max_steps=2, + default_root_dir=tmp_path, + ) + # fit + trainer.fit(module, dataloader) + + # run tests only for rank 0 + if trainer.global_rank != 0: + return + + # load model from checkpoint + ckpt_path = tmp_path / "lightning_logs/version_0/checkpoints/epoch=0-step=2.ckpt" + assert ckpt_path.is_file() + loaded_model = CellariumModule.load_from_checkpoint(ckpt_path).model + assert isinstance(loaded_model, CellariumGPT) + # assert + assert model.attention_backend == loaded_model.attention_backend + assert model.embeddings_scale == loaded_model.embeddings_scale + assert model.attention_logits_scale == loaded_model.attention_logits_scale + assert model.output_logits_scale == loaded_model.output_logits_scale