Skip to content

Commit

Permalink
feat: add gene matrix parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
macwiatrak committed Feb 1, 2025
1 parent 4eb351a commit f296ab8
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
10 changes: 7 additions & 3 deletions bactgraph/modeling/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ def __init__(self, config: dict):
)

# self.linear = nn.Linear(config["output_dim"], 1)
self.bias = torch.nn.Parameter(torch.zeros(config["n_genes"]))
self.bias = torch.nn.Parameter(torch.zeros(config["n_genes"])).unsqueeze(1)
self.dropout = nn.Dropout(config["dropout"])
self.gene_matrix = nn.Parameter(torch.empty(config["n_genes"], config["output_dim"]))
nn.init.xavier_normal_(self.gene_matrix)
# self.dropout = nn.Dropout(config["dropout"])

# Learning rate (default to 1e-3 if not specified)
Expand All @@ -131,8 +134,9 @@ def forward(self, x_batch: torch.Tensor, edge_index_batch: torch.Tensor, gene_in
# batch_size = x_batch.shape[0]
# logits = self.gat_module(x, edge_index).squeeze() # + self.bias.repeat(batch_size)
last_hidden_state = self.gat_module(x, edge_index)
last_hidden_state = group_by_label(last_hidden_state, gene_indices.view(-1))
return F.softplus(last_hidden_state)
last_hidden_state = group_by_label(self.dropout(last_hidden_state), gene_indices.view(-1))
logits = torch.einsum("bnm,bm->bn", last_hidden_state, self.gene_matrix) + self.bias
return F.softplus(logits)

def training_step(self, batch, batch_idx):
"""Training step."""
Expand Down
4 changes: 2 additions & 2 deletions bactgraph/modeling/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def __init__(self):
test: bool = False
input_dim: int = 480
hidden_dim: int = 480
output_dim: int = 1
num_layers: int = 3
output_dim: int = 480
num_layers: int = 2
num_heads: int = 4
dropout: float = 0.2
lr: float = 0.001
Expand Down

0 comments on commit f296ab8

Please sign in to comment.