
CellariumGPT #129 (Open)

wants to merge 20 commits into main from cellarium-gpt
Conversation

ordabayevy (Contributor)

No description provided.

@ordabayevy force-pushed the cellarium-gpt branch 5 times, most recently from c82fbeb to 43ed76b on March 7, 2024 18:07
@ordabayevy requested a review from mbabadi on March 7, 2024 18:10
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(self.initializer_range / math.sqrt(2 * self.gpt_model.n_blocks)))

def tokenize(
mbabadi (Member):

documentation

mbabadi (Member):

Can you also separately take obs_total_mrna_umis_n (to be concatenated to values_nc) and out_total_mrna_umis_n to generate a token specifying the library size we need the readout to be at? This construct can be used for model training with downsampling. We could use Brice's logic: duplicate the counts, downsample the first copy to use for observation, and use the second copy for readout.
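A minimal sketch of the duplicate-and-downsample construct described above; the downsampling fraction and the function name are assumptions:

    import torch

    def make_observation_and_readout(x_ng: torch.Tensor, downsample_p: float = 0.5):
        # Duplicate the counts: binomially downsample one copy for the observation,
        # keep the other copy intact for the readout target.
        obs_x_ng = torch.distributions.Binomial(total_count=x_ng, probs=torch.tensor(downsample_p)).sample()
        obs_total_mrna_umis_n = obs_x_ng.sum(dim=-1)
        out_total_mrna_umis_n = x_ng.sum(dim=-1)
        return obs_x_ng, obs_total_mrna_umis_n, out_total_mrna_umis_n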

Another thought is to have multiple tokenize methods. For example:
-- generate_observation_tokens
-- generate_output_tokens
-- generate_register_tokens (later)
-- generate_metadata_tokens (later)

In this construct, generate_observation_tokens takes x_ng, obs_total_mrna_umis_n as input and generates a bunch of tokens. generate_output_tokens, for now, just takes out_total_mrna_umis_n and generates a single token.

All of the tokens are then concatenated and given to an embedding layer. The tokens should also be equipped with metadata to specify how they should be embedded ...
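A minimal sketch of this split; only the method names and argument names come from the comment above, while the class name, token-type codes, and dictionary layout are assumptions:

    import torch

    class GPTTokenizer:
        OBS, OUT = 0, 1  # hypothetical token-type codes serving as embedding metadata

        def generate_observation_tokens(
            self, x_ng: torch.Tensor, obs_total_mrna_umis_n: torch.Tensor
        ) -> dict[str, torch.Tensor]:
            # Concatenate the observed library-size token to the gene value tokens.
            values_nc = torch.cat([x_ng, obs_total_mrna_umis_n[:, None]], dim=1)
            return {"values": values_nc, "token_type": torch.full_like(values_nc, self.OBS)}

        def generate_output_tokens(self, out_total_mrna_umis_n: torch.Tensor) -> dict[str, torch.Tensor]:
            # For now, a single token carrying the library size the readout should be at.
            values_n1 = out_total_mrna_umis_n[:, None]
            return {"values": values_n1, "token_type": torch.full_like(values_n1, self.OUT)}

        def tokenize(self, x_ng, obs_total_mrna_umis_n, out_total_mrna_umis_n) -> dict[str, torch.Tensor]:
            obs = self.generate_observation_tokens(x_ng, obs_total_mrna_umis_n)
            out = self.generate_output_tokens(out_total_mrna_umis_n)
            # All tokens are concatenated before being handed to the embedding layer.
            return {key: torch.cat([obs[key], out[key]], dim=1) for key in obs}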

mbabadi (Member), Mar 13, 2024:

Made an issue #131 about this.

# - torch.lgamma(self.theta)
# - torch.lgamma(value + 1)
# )
delta = torch.where(
mbabadi (Member):

Add documentation:

"For large theta, we can use Stirling's asymptotic approximation (see https://en.wikipedia.org/wiki/Gamma_function#Log-gamma_function), which is numerically more stable than PyTorch's implementation of lgamma."

Actually, the condition value / theta < 1e-2 may not be needed. Let's make an issue for me to investigate.
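For reference, a minimal sketch of the approximation in question; the crossover threshold below is a placeholder, not the value used in the PR:

    import math
    import torch

    def lgamma_stirling(x: torch.Tensor) -> torch.Tensor:
        # Stirling's asymptotic series for log Gamma(x), keeping the first correction term:
        # lgamma(x) ~ (x - 1/2) * log(x) - x + 0.5 * log(2 * pi) + 1 / (12 * x)
        return (x - 0.5) * torch.log(x) - x + 0.5 * math.log(2 * math.pi) + 1.0 / (12.0 * x)

    # Use the asymptotic form only where it is accurate (threshold is a placeholder).
    x = torch.tensor([0.5, 5.0, 50.0, 5000.0])
    out = torch.where(x > 10.0, lgamma_stirling(x), torch.lgamma(x))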

if (trainer.global_step + 1) % (trainer.log_every_n_steps * 10) != 0: # type: ignore[attr-defined]
return

import matplotlib.pyplot as plt
mbabadi (Member):

I suggest refactoring these plotting functions out to keep on_batch_end decluttered.
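A minimal sketch of the suggested refactor; the helper name and the quantities being plotted are assumptions:

    import matplotlib.pyplot as plt
    import torch

    def plot_diagnostics(value_nc: torch.Tensor, mu_nc: torch.Tensor) -> plt.Figure:
        # Build the diagnostic figure outside the callback so on_batch_end stays small.
        fig, ax = plt.subplots()
        ax.scatter(value_nc.flatten().cpu(), mu_nc.flatten().cpu(), s=1)
        ax.set(xlabel="observed count", ylabel="predicted mean")
        return fig

    # on_batch_end then reduces to the periodicity guard plus a single call:
    #     fig = plot_diagnostics(value_nc, mu_nc)
    #     trainer.logger.experiment.add_figure("diagnostics", fig, trainer.global_step)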

@register_model
def cellarium_gpt(args: ArgsType = None) -> None:
r"""
CLI to run the :class:`cellarium.ml.models.CellariumGPT` model.
mbabadi (Member):

Add an example CLI command.
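For instance, following the docstring pattern of the other registered models (the subcommand and config path here are assumptions):

    r"""
    CLI to run the :class:`cellarium.ml.models.CellariumGPT` model.

    Example run::

        cellarium-ml cellarium_gpt fit --config examples/cellarium_gpt/config.yaml
    """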

mbabadi (Member) left a review comment:

Some small changes here and there.

return mu_nc, theta_nc


class CellariumGPT(CellariumModel):
Contributor:

Can you make it a PredictMixin and implement predict()?

Contributor:

Although I'm thinking that, analogous to Geneformer, predict() would return the gene embedding vectors. But I'm not sure that exactly makes sense here... you might also imagine that "predict" would return the negative binomial distributions.

My interest was in having a common interface I could use to extract gene embeddings, whether the model was a Geneformer model or a CellariumGPT model. I was previously using .predict() to do this from Geneformer, so I thought it would be nice if CellariumGPT worked the same way.

Contributor:

See #152 if you're interested in this idea.

ordabayevy (Contributor, Author):

I've been locally implementing different versions of the predict method depending on what I wanted to analyze. Since I was changing it a lot, I decided not to add it here. But it should be something similar to the Geneformer predict that returns a dictionary, with boolean flags to control what it returns.
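A minimal sketch of that dictionary-returning predict(); the flag names, the head call, and the exact imports are assumptions:

    import torch
    from cellarium.ml.models import CellariumModel, PredictMixin

    class CellariumGPT(CellariumModel, PredictMixin):
        def predict(
            self,
            x_ng: torch.Tensor,
            output_embeddings: bool = True,
            output_distributions: bool = False,
        ) -> dict[str, torch.Tensor]:
            output: dict[str, torch.Tensor] = {}
            hidden_ncd = self.gpt_model(x_ng)  # hypothetical transformer forward pass
            if output_embeddings:
                output["gene_embeddings_ncd"] = hidden_ncd
            if output_distributions:
                mu_nc, theta_nc = self.head(hidden_ncd)  # hypothetical negative binomial head
                output["mu_nc"], output["theta_nc"] = mu_nc, theta_nc
            return output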

@ordabayevy ordabayevy changed the base branch from main to val-dataloader May 7, 2024 17:41
Base automatically changed from val-dataloader to main May 8, 2024 17:24
@ordabayevy ordabayevy changed the base branch from main to timer May 8, 2024 22:45
@ordabayevy ordabayevy changed the base branch from timer to main May 8, 2024 22:45