Skip to content

Commit

Permalink
Improve media logging support in WandbLogger (#18164)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Bharat Ramanathan <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: awaelchli <[email protected]>
  • Loading branch information
5 people authored Nov 15, 2023
1 parent 54593b0 commit 008a83e
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 1 deletion.
52 changes: 52 additions & 0 deletions src/lightning/pytorch/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,58 @@ def log_image(self, key: str, images: List[Any], step: Optional[int] = None, **k
metrics = {key: [wandb.Image(img, **kwarg) for img, kwarg in zip(images, kwarg_list)]}
self.log_metrics(metrics, step) # type: ignore[arg-type]

@rank_zero_only
def log_audio(self, key: str, audios: List[Any], step: Optional[int] = None, **kwargs: Any) -> None:
    r"""Log audios (numpy arrays, or file paths).

    Args:
        key: The key to be used for logging the audio files
        audios: The list of audio file paths, or numpy arrays to be logged
        step: The step number to be used for logging the audio files
        \**kwargs: Optional per-audio arguments (ex: caption, sample_rate). Each value must be a list with one
            entry per audio; entry ``i`` is passed to the ``wandb.Audio`` instance created for ``audios[i]``.

    Raises:
        TypeError: If ``audios`` is not a list, or if a kwarg value has no length (e.g. a scalar ``sample_rate``).
        ValueError: If a kwarg value does not hold exactly one entry per audio.

    """
    if not isinstance(audios, list):
        raise TypeError(f'Expected a list as "audios", found {type(audios)}')
    n = len(audios)
    for k, v in kwargs.items():
        try:
            found = len(v)
        except TypeError:
            # A scalar such as ``sample_rate=16000`` would otherwise surface as an opaque "has no len()" error.
            raise TypeError(f'Expected a list of {n} items for "{k}", found {type(v)}') from None
        if found != n:
            raise ValueError(f"Expected {n} items but only found {found} for {k}")

    import wandb

    # Build one ``wandb.Audio`` per input, picking the i-th entry of every kwarg list for the i-th audio.
    metrics = {
        key: [wandb.Audio(audio, **{k: v[i] for k, v in kwargs.items()}) for i, audio in enumerate(audios)]
    }
    self.log_metrics(metrics, step)  # type: ignore[arg-type]

@rank_zero_only
def log_video(self, key: str, videos: List[Any], step: Optional[int] = None, **kwargs: Any) -> None:
    r"""Log videos (numpy arrays, or file paths).

    Args:
        key: The key to be used for logging the video files
        videos: The list of video file paths, or numpy arrays to be logged
        step: The step number to be used for logging the video files
        \**kwargs: Optional per-video arguments (ex: caption, fps, format). Each value must be a list with one
            entry per video; entry ``i`` is passed to the ``wandb.Video`` instance created for ``videos[i]``.

    Raises:
        TypeError: If ``videos`` is not a list, or if a kwarg value has no length (e.g. a scalar ``fps``).
        ValueError: If a kwarg value does not hold exactly one entry per video.

    """
    if not isinstance(videos, list):
        raise TypeError(f'Expected a list as "videos", found {type(videos)}')
    n = len(videos)
    for k, v in kwargs.items():
        try:
            found = len(v)
        except TypeError:
            # A scalar such as ``fps=4`` would otherwise surface as an opaque "has no len()" error.
            raise TypeError(f'Expected a list of {n} items for "{k}", found {type(v)}') from None
        if found != n:
            raise ValueError(f"Expected {n} items but only found {found} for {k}")

    import wandb

    # Build one ``wandb.Video`` per input, picking the i-th entry of every kwarg list for the i-th video.
    metrics = {
        key: [wandb.Video(video, **{k: v[i] for k, v in kwargs.items()}) for i, video in enumerate(videos)]
    }
    self.log_metrics(metrics, step)  # type: ignore[arg-type]

@property
@override
def save_dir(self) -> Optional[str]:
Expand Down
10 changes: 9 additions & 1 deletion tests/tests_pytorch/loggers/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,13 @@ class RunType: # to make isinstance checks pass
pass

run_mock = Mock(
spec=RunType, log=Mock(), config=Mock(), watch=Mock(), log_artifact=Mock(), use_artifact=Mock(), id="run_id"
spec=RunType,
log=Mock(),
config=Mock(),
watch=Mock(),
log_artifact=Mock(),
use_artifact=Mock(),
id="run_id",
)

wandb = ModuleType("wandb")
Expand All @@ -58,6 +64,8 @@ class RunType: # to make isinstance checks pass
wandb.Api = Mock()
wandb.Artifact = Mock()
wandb.Image = Mock()
wandb.Audio = Mock()
wandb.Video = Mock()
wandb.Table = Mock()
monkeypatch.setitem(sys.modules, "wandb", wandb)

Expand Down
58 changes: 58 additions & 0 deletions tests/tests_pytorch/loggers/test_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,64 @@ def test_wandb_log_media(wandb_mock, tmp_path):
with pytest.raises(ValueError, match="Expected 2 items but only found 1 for caption"):
logger.log_image(key="samples", images=["1.jpg", "2.jpg"], caption=["caption 1"])

# test log_audio
wandb_mock.init().log.reset_mock()
logger.log_audio(key="samples", audios=["1.mp3", "2.mp3"])
wandb_mock.Audio.assert_called_with("2.mp3")
wandb_mock.init().log.assert_called_once_with({"samples": [wandb_mock.Audio(), wandb_mock.Audio()]})

# test log_audio with step
wandb_mock.init().log.reset_mock()
logger.log_audio(key="samples", audios=["1.mp3", "2.mp3"], step=5)
wandb_mock.Audio.assert_called_with("2.mp3")
wandb_mock.init().log.assert_called_once_with(
{"samples": [wandb_mock.Audio(), wandb_mock.Audio()], "trainer/global_step": 5}
)

# test log_audio with captions
wandb_mock.init().log.reset_mock()
wandb_mock.Audio.reset_mock()
logger.log_audio(key="samples", audios=["1.mp3", "2.mp3"], caption=["caption 1", "caption 2"])
wandb_mock.Audio.assert_called_with("2.mp3", caption="caption 2")
wandb_mock.init().log.assert_called_once_with({"samples": [wandb_mock.Audio(), wandb_mock.Audio()]})

# test log_audio without a list
with pytest.raises(TypeError, match="""Expected a list as "audios", found <class 'str'>"""):
logger.log_audio(key="samples", audios="1.mp3")

# test log_audio with wrong number of captions
with pytest.raises(ValueError, match="Expected 2 items but only found 1 for caption"):
logger.log_audio(key="samples", audios=["1.mp3", "2.mp3"], caption=["caption 1"])

# test log_video
wandb_mock.init().log.reset_mock()
logger.log_video(key="samples", videos=["1.mp4", "2.mp4"])
wandb_mock.Video.assert_called_with("2.mp4")
wandb_mock.init().log.assert_called_once_with({"samples": [wandb_mock.Video(), wandb_mock.Video()]})

# test log_video with step
wandb_mock.init().log.reset_mock()
logger.log_video(key="samples", videos=["1.mp4", "2.mp4"], step=5)
wandb_mock.Video.assert_called_with("2.mp4")
wandb_mock.init().log.assert_called_once_with(
{"samples": [wandb_mock.Video(), wandb_mock.Video()], "trainer/global_step": 5}
)

# test log_video with captions
wandb_mock.init().log.reset_mock()
wandb_mock.Video.reset_mock()
logger.log_video(key="samples", videos=["1.mp4", "2.mp4"], caption=["caption 1", "caption 2"])
wandb_mock.Video.assert_called_with("2.mp4", caption="caption 2")
wandb_mock.init().log.assert_called_once_with({"samples": [wandb_mock.Video(), wandb_mock.Video()]})

# test log_video without a list
with pytest.raises(TypeError, match="""Expected a list as "videos", found <class 'str'>"""):
logger.log_video(key="samples", videos="1.mp4")

# test log_video with wrong number of captions
with pytest.raises(ValueError, match="Expected 2 items but only found 1 for caption"):
logger.log_video(key="samples", videos=["1.mp4", "2.mp4"], caption=["caption 1"])

# test log_table
wandb_mock.Table.reset_mock()
wandb_mock.init().log.reset_mock()
Expand Down

0 comments on commit 008a83e

Please sign in to comment.