diff --git a/bactgraph/modeling/dataset.py b/bactgraph/modeling/dataset.py index f1cac11..af72562 100644 --- a/bactgraph/modeling/dataset.py +++ b/bactgraph/modeling/dataset.py @@ -72,7 +72,7 @@ def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # get the expression data for the idx-th strain strain = self.strains[idx] # get protein embeddings - prot_emb = torch.tensor(self.protein_embeddings.loc[strain].values, dtype=torch.float32) + prot_emb = torch.tensor(np.stack(self.protein_embeddings.loc[strain].values), dtype=torch.float32) expr_values = torch.tensor( [self.expression_df.loc[gene, strain] for gene in self.protein_embeddings.columns], dtype=torch.float32 )