Skip to content

Commit

Permalink
fix: convert edge indices to long tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
macwiatrak committed Feb 1, 2025
1 parent e20ea14 commit 1af18b6
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions bactgraph/modeling/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,30 +127,30 @@ def __init__(self, config: dict):

def forward(self, x: torch.Tensor, edge_index: torch.Tensor):
"""Expects a PyG data object with data.x (node features) and data.edge_index (graph connectivity)."""
logits = self.gat_module(x, edge_index) + self.bias
logits = self.gat_module(x, edge_index).squeeze() + self.bias
return F.softplus(logits)

def training_step(self, batch, batch_idx):
"""Training step."""
x_batch, edge_index_batch, y = batch
x, edge_index, _ = batch_into_single_graph(x_batch, edge_index_batch)
x, edge_index, _ = batch_into_single_graph(x_batch, edge_index_batch.type(torch.long))
print(x.shape, edge_index.shape, y.shape)
preds = self.forward(x, edge_index)
# Squeeze if your output_dim=1 to match target shape
loss = F.mse_loss(preds.squeeze(), y.view(-1))
loss = F.mse_loss(preds, y.view(-1))
self.log("train_loss", loss, on_step=False, on_epoch=True)
return loss

def validation_step(self, batch, batch_idx):
"""Validation step"""
x_batch, edge_index_batch, y = batch
x, edge_index, _ = batch_into_single_graph(x_batch, edge_index_batch)
x, edge_index, _ = batch_into_single_graph(x_batch, edge_index_batch.type(torch.long))
print(x.shape, edge_index.shape, y.shape)
preds = self.forward(x, edge_index)

loss = F.mse_loss(preds.squeeze(), y.view(-1))
pearson = pearson_corrcoef(preds.squeeze(), y.view(-1))
r2 = r2_score(preds.squeeze(), y.view(-1))
loss = F.mse_loss(preds, y.view(-1))
pearson = pearson_corrcoef(preds, y.view(-1))
r2 = r2_score(preds, y.view(-1))

res = {"test_loss": loss, "test_pearson": pearson, "test_r2": r2}
self.log_dict(res, prog_bar=True, batch_size=self.config["batch_size"])
Expand All @@ -160,12 +160,12 @@ def validation_step(self, batch, batch_idx):
def test_step(self, batch, batch_idx) -> dict:
"""Test step."""
x_batch, edge_index_batch, y = batch
x, edge_index, _ = batch_into_single_graph(x_batch, edge_index_batch)
x, edge_index, _ = batch_into_single_graph(x_batch, edge_index_batch.type(torch.long))
preds = self.forward(x, edge_index)

loss = F.mse_loss(preds.squeeze(), y.view(-1))
pearson = pearson_corrcoef(preds.squeeze(), y.view(-1))
r2 = r2_score(preds.squeeze(), y.view(-1))
loss = F.mse_loss(preds, y.view(-1))
pearson = pearson_corrcoef(preds, y.view(-1))
r2 = r2_score(preds, y.view(-1))

res = {"test_loss": loss, "test_pearson": pearson, "test_r2": r2}
self.log_dict(res, prog_bar=True, batch_size=self.config["batch_size"])
Expand Down

0 comments on commit 1af18b6

Please sign in to comment.