Skip to content

Commit

Permalink
🩹 Fix documentation code links
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Sep 18, 2022
1 parent 1e7604c commit 62e2859
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
26 changes: 13 additions & 13 deletions lampe/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from .utils import gridapply


@torch.no_grad()
def expected_coverage_mc(
posterior: Callable[[Tensor], Distribution],
pairs: Iterable[Tuple[Tensor, Tensor]],
Expand Down Expand Up @@ -60,14 +59,15 @@ def expected_coverage_mc(

ranks = []

for theta, x in tqdm(pairs, unit='pair'):
dist = posterior(x)
with torch.no_grad():
for theta, x in tqdm(pairs, unit='pair'):
dist = posterior(x)

samples = dist.sample((n,))
mask = dist.log_prob(theta) < dist.log_prob(samples)
rank = mask.sum() / mask.numel()
samples = dist.sample((n,))
mask = dist.log_prob(theta) < dist.log_prob(samples)
rank = mask.sum() / mask.numel()

ranks.append(rank)
ranks.append(rank)

ranks = torch.stack(ranks).cpu()
ranks = torch.cat((ranks, torch.tensor([0.0, 1.0])))
Expand All @@ -78,7 +78,6 @@ def expected_coverage_mc(
)


@torch.no_grad()
def expected_coverage_ni(
posterior: Callable[[Tensor, Tensor], Tensor],
pairs: Iterable[Tuple[Tensor, Tensor]],
Expand Down Expand Up @@ -113,12 +112,13 @@ def expected_coverage_ni(

ranks = []

for theta, x in tqdm(pairs, unit='pair'):
_, log_probs = gridapply(lambda theta: posterior(theta, x), domain, **kwargs)
mask = posterior(theta, x) < log_probs
rank = log_probs[mask].logsumexp(dim=0) - log_probs.flatten().logsumexp(dim=0)
with torch.no_grad():
for theta, x in tqdm(pairs, unit='pair'):
_, log_probs = gridapply(lambda theta: posterior(theta, x), domain, **kwargs)
mask = posterior(theta, x) < log_probs
rank = log_probs[mask].logsumexp(dim=0) - log_probs.flatten().logsumexp(dim=0)

ranks.append(rank.exp())
ranks.append(rank.exp())

ranks = torch.stack(ranks).cpu()
ranks = torch.cat((ranks, torch.tensor([0.0, 1.0])))
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

setuptools.setup(
name='lampe',
version='0.5.4',
version='0.5.5',
packages=setuptools.find_packages(),
description='Likelihood-free AMortized Posterior Estimation with PyTorch',
keywords='parameter inference bayes posterior amortized likelihood ratio mcmc torch',
Expand Down

0 comments on commit 62e2859

Please sign in to comment.