Skip to content

Commit

Permalink
doc: clarified input variables of multiscale flow (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
VincentStimper committed Aug 25, 2024
1 parent b878a7a commit c6616b1
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions normflows/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def __init__(self, q0, flows, merges, transform=None, class_cond=True):
Args:
q0: List of base distribution
flows: List of list of flows for each level
flows: List of flows for each level
merges: List of merge/split operations (forward pass must do merge)
transform: Initial transformation of inputs
class_cond: Flag, indicated whether model has class conditional
Expand All @@ -478,11 +478,11 @@ def __init__(self, q0, flows, merges, transform=None, class_cond=True):
self.class_cond = class_cond

def forward_kld(self, x, y=None):
"""Estimates forward KL divergence, see see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)
"""Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)
Args:
x: Batch sampled from target distribution
y: Batch of targets, if applicable
y: Batch of classes to condition on, if applicable
Returns:
Estimate of forward KL divergence averaged over batch
Expand All @@ -494,7 +494,7 @@ def forward(self, x, y=None):
Args:
x: Batch of data
y: Batch of targets, if applicable
y: Batch of classes to condition on, if applicable
Returns:
Negative log-likelihood of the batch
Expand Down

0 comments on commit c6616b1

Please sign in to comment.