Commit 3f1cc61

Fix format issues

francois-rozet committed Jul 19, 2024
1 parent 96088f6

Showing 2 changed files with 10 additions and 18 deletions.

lampe/inference/dcp.py (24 changes: 9 additions & 15 deletions)
@@ -156,12 +156,10 @@ def forward(self, theta: Tensor, x: Tensor) -> Tensor:
             The scalar loss :math:`l`.
         """
         theta_prime = torch.roll(theta, 1, dims=0)
-        theta_is = self.proposal.sample(
-            (
-                self.n_samples,
-                theta.shape[0],
-            )
-        )
+        theta_is = self.proposal.sample((
+            self.n_samples,
+            theta.shape[0],
+        ))
 
         log_r_all = self.estimator(
             theta_all := torch.cat((torch.stack((theta_prime, theta)), theta_is)),

@@ -256,9 +254,7 @@ def __init__(
         self.forward = self._forward_is
         self.get_rank_statistics = self._get_rank_statistics_is
 
-    def rsample_and_log_prob(
-        self, x: Tensor, shape: Size = ()
-    ) -> Tuple[Tensor, Tensor]:
+    def rsample_and_log_prob(self, x: Tensor, shape: Size = ()) -> Tuple[Tensor, Tensor]:
         r"""
         Arguments:
             x: The observation :math:`x`, with shape :math:`(*, L)`.

@@ -300,12 +296,10 @@ def _forward_is(self, theta: Tensor, x: Tensor) -> Tensor:
         Returns:
             The scalar loss :math:`l`.
         """
-        theta_is = self.proposal.sample(
-            (
-                self.n_samples,
-                theta.shape[0],
-            )
-        )
+        theta_is = self.proposal.sample((
+            self.n_samples,
+            theta.shape[0],
+        ))
         log_p = self.estimator(torch.cat((theta.unsqueeze(0), theta_is)), x)
         lr = self.regularizer(x, log_p, theta_is)
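
For reference, the reformatted sample calls are behaviorally identical to the old ones: the distribution receives the same shape tuple either way. A minimal sketch of the call pattern, using a hypothetical MultivariateNormal proposal in place of self.proposal (the distribution and dimensions are illustrative, not from the repository):

    import torch
    from torch.distributions import MultivariateNormal

    # Stand-in for self.proposal; any torch Distribution behaves the same here.
    proposal = MultivariateNormal(torch.zeros(3), torch.eye(3))

    n_samples, batch = 8, 4

    # Same expression as in the diff, with the parentheses laid out differently:
    # sample() receives the tuple (n_samples, batch) and prepends it to the event shape.
    theta_is = proposal.sample((
        n_samples,
        batch,
    ))

    assert theta_is.shape == (n_samples, batch, 3)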

tests/test_inference.py (4 changes: 1 addition & 3 deletions)
@@ -62,9 +62,7 @@ def test_AMNRE():
 
     assert log_r.shape == (256,)
 
-    grad = torch.autograd.functional.jacobian(
-        lambda theta: estimator(theta, x, b).sum(), theta
-    )
+    grad = torch.autograd.functional.jacobian(lambda theta: estimator(theta, x, b).sum(), theta)
 
     assert (grad[~b] == 0).all()
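
The collapsed jacobian call computes per-entry gradients by differentiating the summed output, so the result has theta's shape and each entry is the gradient of that entry's contribution. A minimal sketch of the same pattern, with a toy masked function standing in for estimator(theta, x, b) (the mask semantics are assumed, not taken from the test):

    import torch

    b = torch.tensor([True, False, True])  # assumed per-dimension mask
    f = lambda theta: (theta * b).square().sum(dim=-1)  # toy masked estimator

    theta = torch.randn(5, 3)

    # Jacobian of the scalar sum: the gradient of the output with respect to
    # every entry of theta, returned with theta's shape.
    grad = torch.autograd.functional.jacobian(lambda theta: f(theta).sum(), theta)

    assert grad.shape == theta.shape
    assert (grad[:, ~b] == 0).all()  # masked dimensions receive zero gradient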
