From 92efb50ecd5e4bf8e9e8a2798fca4206f985249d Mon Sep 17 00:00:00 2001 From: Lilferrit Date: Wed, 18 Sep 2024 14:58:45 -0700 Subject: [PATCH] tensorboard logger --- casanovo/denovo/model_runner.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index 7016e59b..80ab3cbc 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -4,23 +4,18 @@ import glob import logging import os -import re import tempfile import warnings from pathlib import Path from typing import Iterable, List, Optional, Union from datetime import datetime -import depthcharge.masses import lightning.pytorch as pl import torch +import torch.utils.data from lightning.pytorch.strategies import DDPStrategy -from lightning.pytorch.callbacks import ( - ModelCheckpoint, - LearningRateMonitor, - EarlyStopping, -) +from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.loggers import TensorBoardLogger from depthcharge.tokenizers import PeptideTokenizer @@ -161,7 +156,9 @@ def train( self.loaders.val_dataloader(), ) - def log_metrics(self, test_index: AnnotatedSpectrumIndex) -> None: + def log_metrics( + self, test_dataloader: torch.utils.data.DataLoader + ) -> None: """Log peptide precision and amino acid precision Calculate and log peptide precision and amino acid precision @@ -172,7 +169,7 @@ def log_metrics(self, test_index: AnnotatedSpectrumIndex) -> None: test_index : AnnotatedSpectrumIndex Index containing the annotated spectra used to generate model predictions - """ + model_output = [psm.sequence for psm in self.writer.psms] spectrum_annotations = [ test_index[i][4] for i in range(test_index.n_spectra) @@ -187,6 +184,9 @@ def log_metrics(self, test_index: AnnotatedSpectrumIndex) -> None: logger.info("Peptide Precision: %.2f%%", 100 * pep_precision) logger.info("Amino Acid Precision: %.2f%%", 100 * aa_precision) + """ + # TODO: Fix log_metrics, wait for eval bug fix to be merged in + return def predict( self, @@ -234,7 +234,7 @@ def predict( self.trainer.predict(self.model, self.loaders.test_dataloader()) if evaluate: - self.log_metrics(test_index) + self.log_metrics(self.loaders.test_dataloader()) def initialize_trainer(self, train: bool) -> None: """Initialize the lightning Trainer. @@ -259,11 +259,11 @@ def initialize_trainer(self, train: bool) -> None: else: devices = self.config.devices - if self.config.tb_summarywriter is not None: + # TODO: CSV logger + if self.config.tb_summarywriter: logger = TensorBoardLogger( - self.config.tb_summarywriter, - version=None, - name=f'model_{datetime.now().strftime("%Y%m%d_%H%M")}', + self.output_dir, + version="tensorboard", default_hp_metric=False, ) else: