CellariumGPT #129
base: main
Conversation
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(self.initializer_range / math.sqrt(2 * self.gpt_model.n_blocks)))

def tokenize(
Add documentation for `tokenize`.
Can you also separately take `obs_total_mrna_umis_n` (to be concatenated to `values_nc`) and `out_total_mrna_umis_n` to generate a token specifying the library size we need the readout to be at? This construct can be used for model training with downsampling. We could use Brice's logic to `Duplicate` the counts, downsample the first copy to use for observation, and use the second copy for readout.

Another thought is to have multiple `tokenize` methods. For example:

- `generate_observation_tokens`
- `generate_output_tokens`
- `generate_register_tokens` (later)
- `generate_metadata_tokens` (later)

In this construct, `generate_observation_tokens` takes `x_ng` and `obs_total_mrna_umis_n` as input and generates a bunch of tokens. `generate_output_tokens`, for now, just takes `out_total_mrna_umis_n` and generates a single token.

All of the tokens are then concatenated and given to an embedding layer. The tokens should also be equipped with metadata to specify how they should be embedded ...
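A minimal sketch of what this split could look like, assuming `x_ng` is an `(n, g)` float count matrix and using binomial thinning as a stand-in for the duplicate-and-downsample logic (all function names and the thinning rate here are hypothetical):

```python
import torch


def generate_observation_tokens(x_ng: torch.Tensor, obs_total_mrna_umis_n: torch.Tensor) -> torch.Tensor:
    # Append the observed library size to the per-gene values: (n, g) -> (n, g + 1).
    return torch.cat([x_ng, obs_total_mrna_umis_n.unsqueeze(-1)], dim=-1)


def generate_output_tokens(out_total_mrna_umis_n: torch.Tensor) -> torch.Tensor:
    # A single token per cell specifying the target readout library size: (n, 1).
    return out_total_mrna_umis_n.unsqueeze(-1)


# Training with downsampling: duplicate the counts, downsample the first copy
# for observation, and keep the second copy for readout.
x_ng = torch.tensor([[10.0, 0.0, 5.0], [2.0, 7.0, 1.0]])
x_obs_ng = torch.binomial(x_ng, torch.full_like(x_ng, 0.5))  # thin each count at rate 0.5
x_out_ng = x_ng.clone()

obs_tokens = generate_observation_tokens(x_obs_ng, x_obs_ng.sum(-1))  # (2, 4)
out_tokens = generate_output_tokens(x_out_ng.sum(-1))  # (2, 1)
tokens = torch.cat([obs_tokens, out_tokens], dim=-1)  # then fed to an embedding layer
```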
Made an issue #131 about this.
cellarium/ml/models/cellarium_gpt.py
# - torch.lgamma(self.theta)
# - torch.lgamma(value + 1)
# )
delta = torch.where(
Add documentation:

"For large `theta`, we can use Stirling's asymptotic approximation (see https://en.wikipedia.org/wiki/Gamma_function#Log-gamma_function), which is numerically more stable than PyTorch's implementation of lgamma."

Actually, the condition `value / theta < 1e-2` may not be needed. Let's make an issue for me to investigate.
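For reference, a hedged sketch of that approximation in code form, using the leading terms of Stirling's series ln Γ(z) ≈ (z - 1/2) ln z - z + (1/2) ln 2π + 1/(12z); how it is gated into the NB log-prob is an assumption:

```python
import math

import torch


def lgamma_stirling(z: torch.Tensor) -> torch.Tensor:
    # Leading terms of Stirling's asymptotic series for log-gamma; accurate for large z.
    return (z - 0.5) * torch.log(z) - z + 0.5 * math.log(2 * math.pi) + 1.0 / (12.0 * z)


theta = torch.tensor([1e4, 1e6])
value = torch.tensor([3.0, 50.0])
# The lgamma difference as it appears in the negative binomial log-prob.
delta = lgamma_stirling(value + theta) - lgamma_stirling(theta)
print(torch.allclose(delta, torch.lgamma(value + theta) - torch.lgamma(theta), atol=1e-4))  # True
```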
if (trainer.global_step + 1) % (trainer.log_every_n_steps * 10) != 0:  # type: ignore[attr-defined]
    return

import matplotlib.pyplot as plt
I suggest refactoring these plotting functions out to keep `on_batch_end` decluttered.
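For instance, the figure construction could live in a module-level helper so the hook only decides when to log; the helper name and tensor names here are hypothetical:

```python
import matplotlib.pyplot as plt
import torch


def plot_predicted_vs_observed(x_n: torch.Tensor, mu_n: torch.Tensor) -> plt.Figure:
    # Build the diagnostic figure outside the hook to keep on_batch_end short.
    fig, ax = plt.subplots()
    ax.scatter(x_n.detach().cpu().numpy(), mu_n.detach().cpu().numpy(), s=5)
    ax.set_xlabel("observed counts")
    ax.set_ylabel("predicted mean")
    return fig
```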
@register_model
def cellarium_gpt(args: ArgsType = None) -> None:
    r"""
    CLI to run the :class:`cellarium.ml.models.CellariumGPT` model.
Add example CLI command.
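For instance (assuming the same `cellarium-ml <model> fit --config ...` pattern as the other registered models; the config filename is a placeholder):

```bash
cellarium-ml cellarium_gpt fit --config cellarium_gpt_config.yaml
```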
Some small changes here and there.

    Generating sample cellarium_gpt.yaml config file
    return mu_nc, theta_nc


class CellariumGPT(CellariumModel):
Can you make it a `PredictMixin` and implement `predict()`?
Although I'm thinking that, analogous to Geneformer, `predict()` would return the gene embedding vectors. But I'm not sure that exactly makes sense here... you might also imagine that "predict" would return the negative binomial distributions.

My interest was in having a common interface I could use to extract gene embeddings, whether the model was a Geneformer model or a CellariumGPT model. I was previously using `.predict()` to do this from Geneformer, so I thought it would be nice if CellariumGPT worked the same way.
See #152 if you're interested in this idea.
I've been locally implementing different versions of the `predict` method depending on what I wanted to analyze. Since I was changing it a lot, I decided not to add it here. But it should be something similar to the Geneformer `predict` that returns a dictionary, and we can have boolean flags to control what it returns.
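A minimal sketch of that shape, with hypothetical flag names; the forward pass and the NB head wiring are assumptions:

```python
import torch


def predict(
    self,
    x_ng: torch.Tensor,
    obs_total_mrna_umis_n: torch.Tensor,
    return_gene_embeddings: bool = True,
    return_nb_params: bool = False,
) -> dict[str, torch.Tensor]:
    # Geneformer-style predict: a dictionary whose contents are controlled by boolean flags.
    output: dict[str, torch.Tensor] = {}
    hidden_ncd = self.gpt_model(x_ng, obs_total_mrna_umis_n)  # hypothetical forward pass
    if return_gene_embeddings:
        output["gene_embeddings_ncd"] = hidden_ncd
    if return_nb_params:
        mu_nc, theta_nc = self.head(hidden_ncd)  # hypothetical NB head returning (mu_nc, theta_nc)
        output["mu_nc"] = mu_nc
        output["theta_nc"] = theta_nc
    return output
```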