diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py index 88f3aaca..716dd747 100644 --- a/casanovo/denovo/model.py +++ b/casanovo/denovo/model.py @@ -2,6 +2,7 @@ import collections import heapq +import itertools import logging import warnings from pathlib import Path @@ -1062,10 +1063,11 @@ def predict_step( batch_aa_scores, psm_batch[3], ): + spectrum_i = tuple(spectrum_i) predictions_all[spectrum_i].append( ms_io.PepSpecMatch( sequence=peptide, - spectrum_id=tuple(spectrum_i), + spectrum_id=spectrum_i, peptide_score=peptide_score, charge=int(charge), calc_mz=self.peptide_mass_calculator.mass( @@ -1079,16 +1081,20 @@ def predict_step( ) ) # Filter the top-scoring prediction(s) for each spectrum. - predictions = [ - *( - sorted( - spectrum_predictions, - key=lambda p: p.peptide_score, - reverse=True, - )[: self.top_match] - for spectrum_predictions in predictions_all.values() + predictions = list( + itertools.chain.from_iterable( + [ + *( + sorted( + spectrum_predictions, + key=lambda p: p.peptide_score, + reverse=True, + )[: self.top_match] + for spectrum_predictions in predictions_all.values() + ) + ] ) - ] + ) return predictions def on_predict_batch_end(