Skip to content

Commit

Permalink
tensorboard logger
Browse files Browse the repository at this point in the history
  • Loading branch information
Lilferrit committed Sep 18, 2024
1 parent c243c64 commit 92efb50
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,18 @@
import glob
import logging
import os
import re
import tempfile
import warnings
from pathlib import Path
from typing import Iterable, List, Optional, Union
from datetime import datetime

import depthcharge.masses
import lightning.pytorch as pl
import torch
import torch.utils.data

from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.callbacks import (
ModelCheckpoint,
LearningRateMonitor,
EarlyStopping,
)
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from depthcharge.tokenizers import PeptideTokenizer
Expand Down Expand Up @@ -161,7 +156,9 @@ def train(
self.loaders.val_dataloader(),
)

def log_metrics(self, test_index: AnnotatedSpectrumIndex) -> None:
def log_metrics(
self, test_dataloader: torch.utils.data.DataLoader
) -> None:
"""Log peptide precision and amino acid precision
Calculate and log peptide precision and amino acid precision
Expand All @@ -172,7 +169,7 @@ def log_metrics(self, test_index: AnnotatedSpectrumIndex) -> None:
test_index : AnnotatedSpectrumIndex
Index containing the annotated spectra used to generate model
predictions
"""
model_output = [psm.sequence for psm in self.writer.psms]
spectrum_annotations = [
test_index[i][4] for i in range(test_index.n_spectra)
Expand All @@ -187,6 +184,9 @@ def log_metrics(self, test_index: AnnotatedSpectrumIndex) -> None:
logger.info("Peptide Precision: %.2f%%", 100 * pep_precision)
logger.info("Amino Acid Precision: %.2f%%", 100 * aa_precision)
"""
# TODO: Fix log_metrics, wait for eval bug fix to be merged in
return

def predict(
self,
Expand Down Expand Up @@ -234,7 +234,7 @@ def predict(
self.trainer.predict(self.model, self.loaders.test_dataloader())

if evaluate:
self.log_metrics(test_index)
self.log_metrics(self.loaders.test_dataloader())

def initialize_trainer(self, train: bool) -> None:
"""Initialize the lightning Trainer.
Expand All @@ -259,11 +259,11 @@ def initialize_trainer(self, train: bool) -> None:
else:
devices = self.config.devices

if self.config.tb_summarywriter is not None:
# TODO: CSV logger
if self.config.tb_summarywriter:
logger = TensorBoardLogger(
self.config.tb_summarywriter,
version=None,
name=f'model_{datetime.now().strftime("%Y%m%d_%H%M")}',
self.output_dir,
version="tensorboard",
default_hp_metric=False,
)
else:
Expand Down

0 comments on commit 92efb50

Please sign in to comment.