
Implement contrastive learning model and transforms #195

Draft · bricewang wants to merge 8 commits into main from bw-contrastive-new
Conversation

@bricewang (Contributor) commented on May 31, 2024

Adds a contrastive multilayer perceptron (ContrastiveMLP) and the normalized temperature-scaled cross-entropy (NT-Xent) loss

Adds Duplicate, BinomialResample, Dropout, and GaussianNoise transforms

Adds contrastive_mlp to the CLI

@bricewang self-assigned this on May 31, 2024
@bricewang requested a review from @ordabayevy on May 31, 2024 at 06:50
@bricewang force-pushed the bw-contrastive-new branch from 0226328 to ede60bb on May 31, 2024 at 15:54
@bricewang requested a review from @mbabadi on May 31, 2024 at 18:09
@ordabayevy (Contributor) left a comment:

Looks great! Left my initial comments. Can you also add the test in test_cli::test_cpu_multi_device and a checkpoint-loading test similar to the one in https://github.com/cellarium-ai/cellarium-ml/blob/main/tests/test_geneformer.py?

"cellarium.ml.models.ContrastiveMLP",
link_arguments=[
LinkArguments("data", "model.model.init_args.g_genes", compute_n_vars),
LinkArguments("trainer.devices", "model.model.init_args.world_size", None, "parse"),
@ordabayevy (Contributor):

Instead of making world_size a model parameter, I would recommend computing it dynamically in the forward method, like here: https://github.com/cellarium-ai/cellarium-ml/blob/main/cellarium/ml/models/onepass_mean_var_std.py#L85

The reason is that world_size is not a property of the model but of the training procedure. For example, you could train a model with 4 GPUs and use only 1 GPU at inference time, or resume training with 2 GPUs, etc.
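For reference, a minimal sketch of the dynamic lookup (the helper name is an assumption, not code from the diff or the linked file):

import torch.distributed as dist

def get_world_size() -> int:
    # Use the process-group size when torch.distributed is initialized;
    # otherwise fall back to 1 (e.g. single-GPU inference or CPU runs).
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1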

],
trainer_defaults={
"max_epochs": 20,
"strategy": {"class_path": "lightning.pytorch.strategies.DDPStrategy"},
@ordabayevy (Contributor):

I think the DDPStrategy default can be removed, because it is applied automatically when the number of devices is greater than 1.
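For illustration, a sketch of the trimmed defaults (assuming nothing else in the config requires the explicit strategy entry):

# Lightning selects DDPStrategy on its own when trainer.devices > 1,
# so only the remaining defaults need to be listed.
trainer_defaults = {
    "max_epochs": 20,
}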

from cellarium.ml.models.geneformer import Geneformer
from cellarium.ml.models.incremental_pca import IncrementalPCA
from cellarium.ml.models.logistic_regression import LogisticRegression
from cellarium.ml.models.model import CellariumModel, PredictMixin, ValidateMixin
from cellarium.ml.models.mu_linear import MuLinear, abcdParameter
from cellarium.ml.models.nt_xent import NT_Xent
@ordabayevy (Contributor):

Refactor: I suggest creating a cellarium.ml.losses folder and adding our losses there.
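For example, the import here could then become (the exact module path is an assumption):

from cellarium.ml.losses.nt_xent import NT_Xent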


def __init__(
self,
g_genes: int,
@ordabayevy (Contributor):

Naming convention: at some point we switched from the g_genes naming convention to n_obs. You can switch to n_obs here as well to be consistent with the other models.

):
super(ContrastiveMLP, self).__init__()

layer_list: List[nn.Module] = []
@ordabayevy (Contributor):

Minor comment: you can create an empty nn.Sequential() and append to it directly:

self.layers = nn.Sequential()
self.layers.append(module)
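A minimal sketch of building the MLP body this way (hidden_sizes, embed_dim, and the exact layer composition are assumptions for illustration, not the layers in the diff):

import torch.nn as nn

def build_mlp(n_vars: int, hidden_sizes: list[int], embed_dim: int) -> nn.Sequential:
    # Start from an empty container and append layers directly, as suggested above.
    layers = nn.Sequential()
    in_features = n_vars
    for out_features in hidden_sizes:
        layers.append(nn.Linear(in_features, out_features))
        layers.append(nn.ReLU())
        in_features = out_features
    layers.append(nn.Linear(in_features, embed_dim))
    return layers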

loss = self.Xent_loss(z1, z2)
return {"loss": loss}

def predict(self, x_ng: torch.Tensor, **kwargs: Any):
@ordabayevy (Contributor):

Minor comment: **kwargs can be removed, right?

Binomially resampled gene counts.
"""
p_binom_ng = Uniform(self.p_binom_min, self.p_binom_max).sample(x_ng.shape).type_as(x_ng)
p_apply_n = Bernoulli(probs=self.p_apply).sample(x_ng.shape[:1]).type_as(x_ng).bool()
@ordabayevy (Contributor):

Maybe rename p_apply_n to something like apply_mask_n, since it is not a probability anymore.

@bricewang (Author):

Also fixed a bug where x_aug and x_ng were swapped in the torch.where call.
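For clarity, a sketch of the intended selection after the rename and the argument-order fix (the wrapper function is illustrative, not code from the diff):

import torch

def select_augmented(x_ng: torch.Tensor, x_aug: torch.Tensor, apply_mask_n: torch.Tensor) -> torch.Tensor:
    # Rows flagged by apply_mask_n take the resampled counts x_aug;
    # all other rows keep the original counts x_ng.
    return torch.where(apply_mask_n.unsqueeze(-1), x_aug, x_ng)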

Upper bound on binomial distribution parameter.
"""

def __init__(self, p_binom_min, p_binom_max, p_apply):
@ordabayevy (Contributor):

Add type hints to __init__.

Returns:
Binomially resampled gene counts.
"""
p_binom_ng = Uniform(self.p_binom_min, self.p_binom_max).sample(x_ng.shape).type_as(x_ng)
@ordabayevy (Contributor):

Are you using .type_as here to change the device?

Uniform uses the device of self.p_binom_min. If it is a plain float, it will be converted to a torch.Tensor on the CPU, so Uniform.sample will sample on the CPU, which might be slow. In that case it might be better to convert self.p_binom_min to a tensor yourself and move it to the GPU so that the sampling step runs there.
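A sketch of what that could look like, registering the parameters as buffers so they follow the module to the GPU and Uniform samples there directly (the class layout and helper name are assumptions, not the code in the diff):

import torch
from torch import nn
from torch.distributions import Uniform

class BinomialResample(nn.Module):
    def __init__(self, p_binom_min: float, p_binom_max: float, p_apply: float) -> None:
        super().__init__()
        # Buffers move with the module (module.to(device)), so distributions
        # built from them sample on the same device as the input.
        self.register_buffer("p_binom_min", torch.tensor(p_binom_min))
        self.register_buffer("p_binom_max", torch.tensor(p_binom_max))
        self.register_buffer("p_apply", torch.tensor(p_apply))

    def sample_p_binom(self, x_ng: torch.Tensor) -> torch.Tensor:
        # No .type_as needed once the module lives on x_ng's device.
        return Uniform(self.p_binom_min, self.p_binom_max).sample(x_ng.shape)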


class Randomize(nn.Module):
"""
Randomizely applies transform with probability p;
@ordabayevy (Contributor):

Randomizely -> Randomly?

@bricewang (Author):

Alternatively, I removed the entire class, since implementation constraints required that the decision of whether or not to apply the transform be made inside the transform itself (i.e. via p_apply). In the future, this logic could be pulled out into an inheritable base class.

@bricewang force-pushed the bw-contrastive-new branch 2 times, most recently from f33bb07 to 8474e25 on August 14, 2024 at 00:21