Skip to content

Commit

Permalink
fix: fix fetching protein embeddings in the dataset.py
Browse files Browse the repository at this point in the history
  • Loading branch information
macwiatrak committed Feb 1, 2025
1 parent 306e38e commit e20ea14
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions bactgraph/modeling/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def forward(self, x: torch.Tensor, edge_index: torch.Tensor):
def training_step(self, batch, batch_idx):
"""Training step."""
x_batch, edge_index_batch, y = batch
x, edge_index = batch_into_single_graph(x_batch, edge_index_batch)
x, edge_index, _ = batch_into_single_graph(x_batch, edge_index_batch)
print(x.shape, edge_index.shape, y.shape)
preds = self.forward(x, edge_index)
# Squeeze if your output_dim=1 to match target shape
Expand All @@ -144,7 +144,7 @@ def training_step(self, batch, batch_idx):
def validation_step(self, batch, batch_idx):
"""Validation step"""
x_batch, edge_index_batch, y = batch
x, edge_index = batch_into_single_graph(x_batch, edge_index_batch)
x, edge_index, _ = batch_into_single_graph(x_batch, edge_index_batch)
print(x.shape, edge_index.shape, y.shape)
preds = self.forward(x, edge_index)

Expand All @@ -160,7 +160,7 @@ def validation_step(self, batch, batch_idx):
def test_step(self, batch, batch_idx) -> dict:
"""Test step."""
x_batch, edge_index_batch, y = batch
x, edge_index = batch_into_single_graph(x_batch, edge_index_batch)
x, edge_index, _ = batch_into_single_graph(x_batch, edge_index_batch)
preds = self.forward(x, edge_index)

loss = F.mse_loss(preds.squeeze(), y.view(-1))
Expand Down

0 comments on commit e20ea14

Please sign in to comment.