Implement contrastive learning model and transforms #195
base: main
Conversation
Looks great! Left my initial comments. Can you also add the test in test_cli::test_cpu_multi_device, and a checkpoint-loading test similar to the one in https://github.com/cellarium-ai/cellarium-ml/blob/main/tests/test_geneformer.py?
cellarium/ml/cli.py
Outdated
"cellarium.ml.models.ContrastiveMLP", | ||
link_arguments=[ | ||
LinkArguments("data", "model.model.init_args.g_genes", compute_n_vars), | ||
LinkArguments("trainer.devices", "model.model.init_args.world_size", None, "parse"), |
I would recommend, instead of making world_size a model parameter, computing it dynamically in the forward method, like here: https://github.com/cellarium-ai/cellarium-ml/blob/main/cellarium/ml/models/onepass_mean_var_std.py#L85. The reason is that world_size is not a property of the model but of the training procedure. For example, you could train a model with 4 GPUs and use only 1 GPU at inference time, or resume training with 2 GPUs, etc.
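A minimal sketch of that suggestion, assuming torch.distributed is the mechanism (as in the linked onepass_mean_var_std.py); the helper name is hypothetical:

```python
import torch.distributed as dist

def current_world_size() -> int:
    # Hypothetical helper: read world_size from the active process group at
    # call time rather than storing it as a model hyperparameter. Falls back
    # to 1 when torch.distributed is not initialized (e.g. single-GPU
    # inference or CPU runs).
    return dist.get_world_size() if dist.is_initialized() else 1
```

This keeps checkpoints portable across device counts, since nothing about the training topology gets baked into the model's state.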
cellarium/ml/cli.py
Outdated
],
trainer_defaults={
    "max_epochs": 20,
    "strategy": {"class_path": "lightning.pytorch.strategies.DDPStrategy"},
I think the DDPStrategy default can be removed, because it is applied automatically when the number of devices is greater than 1.
cellarium/ml/models/__init__.py
Outdated
from cellarium.ml.models.geneformer import Geneformer
from cellarium.ml.models.incremental_pca import IncrementalPCA
from cellarium.ml.models.logistic_regression import LogisticRegression
from cellarium.ml.models.model import CellariumModel, PredictMixin, ValidateMixin
from cellarium.ml.models.mu_linear import MuLinear, abcdParameter
from cellarium.ml.models.nt_xent import NT_Xent
Refactor: I suggest having a cellarium.ml.losses folder and adding our losses there.
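For illustration, the new package might re-export the loss like this (the file layout is an assumption; only NT_Xent comes from the diff above):

```python
# cellarium/ml/losses/__init__.py (hypothetical new module)
from cellarium.ml.losses.nt_xent import NT_Xent

__all__ = ["NT_Xent"]
```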
def __init__(
    self,
    g_genes: int,
Naming convention: at some point we switched from the g_genes naming convention to n_obs. You can switch to n_obs here as well to be consistent with the other models.
):
    super(ContrastiveMLP, self).__init__()

layer_list: List[nn.Module] = []
Minor comment: you can create an empty nn.Sequential()
and append to it directly:
self.layers = nn.Sequential()
self.layers.append(module)
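Extending that idea, the whole layer-building loop can go through nn.Sequential directly; a sketch, where sizes is a hypothetical stand-in for however the layer widths are derived in the actual model:

```python
import torch.nn as nn

sizes = [512, 256, 128]  # hypothetical layer widths
layers = nn.Sequential()
for n_in, n_out in zip(sizes[:-1], sizes[1:]):
    layers.append(nn.Linear(n_in, n_out))
    layers.append(nn.ReLU())
```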
loss = self.Xent_loss(z1, z2)
return {"loss": loss}

def predict(self, x_ng: torch.Tensor, **kwargs: Any):
Minor comment: **kwargs can be removed, right?
    Binomially resampled gene counts.
"""
p_binom_ng = Uniform(self.p_binom_min, self.p_binom_max).sample(x_ng.shape).type_as(x_ng)
p_apply_n = Bernoulli(probs=self.p_apply).sample(x_ng.shape[:1]).type_as(x_ng).bool()
Maybe rename p_apply_n to something else, like apply_mask_n? Because it is not a probability anymore.
Also fixed a bug where x_aug and x_ng were reversed in torch.where.
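For reference, a self-contained illustration of the described fix (the tensors here are made up; only the argument order matters):

```python
import torch

x_ng = torch.poisson(torch.full((4, 3), 5.0))         # hypothetical counts
x_aug = torch.zeros_like(x_ng)                        # hypothetical augmented counts
p_apply_n = torch.tensor([True, False, True, False])  # per-row apply mask

# Rows where the mask is True take the augmented values; the rest keep the
# original counts. The bug was passing x_aug and x_ng in the opposite order,
# which inverted this behavior.
x_out = torch.where(p_apply_n[:, None], x_aug, x_ng)
```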
    Upper bound on binomial distribution parameter.
"""

def __init__(self, p_binom_min, p_binom_max, p_apply):
Add type hints to __init__.
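For example (assuming all three parameters are floats):

```python
def __init__(self, p_binom_min: float, p_binom_max: float, p_apply: float) -> None:
    ...
```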
Returns:
    Binomially resampled gene counts.
"""
p_binom_ng = Uniform(self.p_binom_min, self.p_binom_max).sample(x_ng.shape).type_as(x_ng)
Are you using .type_as here to change the device? Uniform uses the device of self.p_binom_min. I think if it is of float type then it might be converted to a torch.Tensor on the CPU, and then Uniform.sample will sample on the CPU, which might be slow. In this case it might be better to convert self.p_binom_min to a tensor yourself and cast it to the GPU device so that the sampling step is faster.
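One way to realize that, sketched under the assumption that the transform is the BinomialResample class named in the PR description (the body is illustrative, not the PR's actual code):

```python
import torch
from torch import nn
from torch.distributions import Bernoulli, Binomial, Uniform

class BinomialResample(nn.Module):
    def __init__(self, p_binom_min: float, p_binom_max: float, p_apply: float) -> None:
        super().__init__()
        # Buffers follow the module across devices (e.g. module.to("cuda")),
        # so the distributions below sample directly on the GPU.
        self.register_buffer("p_binom_min", torch.tensor(p_binom_min))
        self.register_buffer("p_binom_max", torch.tensor(p_binom_max))
        self.register_buffer("p_apply", torch.tensor(p_apply))

    def forward(self, x_ng: torch.Tensor) -> torch.Tensor:
        # Once the module has been moved, both bounds already live on x_ng's
        # device, so no .type_as is needed and sampling avoids a CPU round-trip.
        p_binom_ng = Uniform(self.p_binom_min, self.p_binom_max).sample(x_ng.shape)
        apply_mask_n = Bernoulli(probs=self.p_apply).sample(x_ng.shape[:1]).bool()
        x_aug = Binomial(total_count=x_ng, probs=p_binom_ng).sample()
        return torch.where(apply_mask_n[:, None], x_aug, x_ng)
```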
cellarium/ml/transforms/randomize.py
Outdated
class Randomize(nn.Module):
    """
    Randomizely applies transform with probability p;
Randomizely -> Randomly?
Alternatively, I removed the entire class, as implementation constraints dictated that I determine whether or not to apply the transform within the transform itself (i.e. p_apply). In the future, this logic could also be pulled out into some inheritable entity.
- Adds contrastive multilayer perceptron and NT cross-entropy (NT_Xent) loss
- Adds Duplicate, BinomialResample, Dropout, and GaussianNoise transforms
- Adds contrastive_mlp to the CLI