From 8f89738f53d46dd6ad7e265150d7c87de46f1547 Mon Sep 17 00:00:00 2001
From: OlaRonning
Date: Mon, 2 Dec 2024 09:42:07 +0100
Subject: [PATCH 1/2] added lognorm terms for high conc sbvm

---
 pyro/distributions/sine_bivariate_von_mises.py | 18 ++++++++++++------
 .../test_sine_bivariate_von_mises.py           | 12 ++++++++++++
 2 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/pyro/distributions/sine_bivariate_von_mises.py b/pyro/distributions/sine_bivariate_von_mises.py
index ea0a7e05e9..7bd4ed0272 100644
--- a/pyro/distributions/sine_bivariate_von_mises.py
+++ b/pyro/distributions/sine_bivariate_von_mises.py
@@ -35,7 +35,6 @@ class SineBivariateVonMises(TorchDistribution):
     This distribution is a submodel of the Bivariate von Mises distribution, called the Sine Distribution [2] in
     directional statistics.
-
     This distribution is helpful for modeling coupled angles such as torsion angles in peptide chains. To infer
     parameters, use :class:`~pyro.infer.NUTS` or :class:`~pyro.infer.HMC` with priors that avoid parameterizations
     where the distribution becomes bimodal; see note below.
@@ -44,10 +43,12 @@
 
     .. math::
 
-        \frac{\rho}{\kappa_1\kappa_2} \rightarrow 1
+        \frac{\rho^2}{\kappa_1\kappa_2} \rightarrow 1
 
-    because the distribution becomes increasingly bimodal. To avoid bimodality use the `weighted_correlation`
-    parameter with a skew away from one (e.g., Beta(1,3)). The `weighted_correlation` should be in [0,1].
+    because the distribution becomes increasingly bimodal. To avoid inefficient sampling use the
+    `weighted_correlation` parameter with a skew away from one (e.g.,
+    `TransformedDistribution(Beta(5,5), AffineTransform(loc=-1, scale=2))`). The `weighted_correlation`
+    should be in [-1,1].
 
     .. note:: The correlation and weighted_correlation params are mutually exclusive.
@@ -65,7 +66,7 @@ class SineBivariateVonMises(TorchDistribution):
     :param torch.Tensor psi_concentration: concentration of second angle
     :param torch.Tensor correlation: correlation between the two angles
     :param torch.Tensor weighted_correlation: set correlation to weighted_corr * sqrt(phi_conc*psi_conc)
-        to avoid bimodality (see note). The `weighted_correlation` should be in [0,1].
+        to avoid bimodality (see note). The `weighted_correlation` should be in [-1,1].
""" arg_constraints = { @@ -139,7 +140,12 @@ def norm_const(self): + m * torch.log((corr**2).clamp(min=tiny)) - m * torch.log(4 * torch.prod(conc, dim=-1)) ) - fs += log_I1(m.max(), conc, 51).sum(-1) + num_I1terms = torch.maximum( + torch.tensor(501), torch.max(self.phi_concentration) + torch.max(self.psi_concentration) + ).int() + + fs += log_I1(m.max(), conc, num_I1terms).sum(-1) + mfs = fs.max() norm_const = 2 * torch.log(torch.tensor(2 * pi)) + mfs + (fs - mfs).logsumexp(0) return norm_const.reshape(self.phi_loc.shape) diff --git a/tests/distributions/test_sine_bivariate_von_mises.py b/tests/distributions/test_sine_bivariate_von_mises.py index bd676a6690..99ad5f136f 100644 --- a/tests/distributions/test_sine_bivariate_von_mises.py +++ b/tests/distributions/test_sine_bivariate_von_mises.py @@ -130,3 +130,15 @@ def guide(data): ) # k == 'corr' assert_equal(expected[k].squeeze(), actual.squeeze(), 9e-2) + +@pytest.mark.parametrize("conc", [1.0, 10.0, 1000.0, 10000.0]) +def test_sine_bivariate_von_mises_norm(conc): + dist = SineBivariateVonMises(0, 0, conc, conc, 0.0) + num_samples = 500 + x = torch.linspace(-torch.pi, torch.pi, num_samples) + y = torch.linspace(-torch.pi, torch.pi, num_samples) + mesh = torch.stack(torch.meshgrid(x, y, indexing='ij'), axis=-1) + integral_torus = ( + torch.exp(dist.log_prob(mesh)) * (2 * torch.pi) ** 2 / num_samples**2 + ).sum() + assert torch.allclose(integral_torus, torch.tensor(1.0), rtol=1e-2) \ No newline at end of file From 5aa59fbad75785125326c76bd5749f096d090879 Mon Sep 17 00:00:00 2001 From: OlaRonning Date: Mon, 2 Dec 2024 09:54:54 +0100 Subject: [PATCH 2/2] lint --- pyro/distributions/sine_bivariate_von_mises.py | 9 +++++---- tests/distributions/test_sine_bivariate_von_mises.py | 5 +++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pyro/distributions/sine_bivariate_von_mises.py b/pyro/distributions/sine_bivariate_von_mises.py index 7bd4ed0272..40be29ec9a 100644 --- a/pyro/distributions/sine_bivariate_von_mises.py +++ b/pyro/distributions/sine_bivariate_von_mises.py @@ -45,9 +45,9 @@ class SineBivariateVonMises(TorchDistribution): \frac{\rho^2}{\kappa_1\kappa_2} \rightarrow 1 - because the distribution becomes increasingly bimodal. To avoid inefficient sampling use the - `weighted_correlation` parameter with a skew away from one (e.g., - `TransformedDistribution(Beta(5,5), AffineTransform(loc=-1, scale=2))`). The `weighted_correlation` + because the distribution becomes increasingly bimodal. To avoid inefficient sampling use the + `weighted_correlation` parameter with a skew away from one (e.g., + `TransformedDistribution(Beta(5,5), AffineTransform(loc=-1, scale=2))`). The `weighted_correlation` should be in [-1,1]. .. note:: The correlation and weighted_correlation params are mutually exclusive. 
@@ -141,7 +141,8 @@ def norm_const(self):
             - m * torch.log(4 * torch.prod(conc, dim=-1))
         )
         num_I1terms = torch.maximum(
-            torch.tensor(501), torch.max(self.phi_concentration) + torch.max(self.psi_concentration)
+            torch.tensor(501),
+            torch.max(self.phi_concentration) + torch.max(self.psi_concentration),
         ).int()
 
         fs += log_I1(m.max(), conc, num_I1terms).sum(-1)
diff --git a/tests/distributions/test_sine_bivariate_von_mises.py b/tests/distributions/test_sine_bivariate_von_mises.py
index 99ad5f136f..6212220dcb 100644
--- a/tests/distributions/test_sine_bivariate_von_mises.py
+++ b/tests/distributions/test_sine_bivariate_von_mises.py
@@ -131,14 +131,15 @@ def guide(data):
     assert_equal(expected[k].squeeze(), actual.squeeze(), 9e-2)
 
+
 @pytest.mark.parametrize("conc", [1.0, 10.0, 1000.0, 10000.0])
 def test_sine_bivariate_von_mises_norm(conc):
     dist = SineBivariateVonMises(0, 0, conc, conc, 0.0)
     num_samples = 500
     x = torch.linspace(-torch.pi, torch.pi, num_samples)
     y = torch.linspace(-torch.pi, torch.pi, num_samples)
-    mesh = torch.stack(torch.meshgrid(x, y, indexing='ij'), axis=-1)
+    mesh = torch.stack(torch.meshgrid(x, y, indexing="ij"), axis=-1)
     integral_torus = (
         torch.exp(dist.log_prob(mesh)) * (2 * torch.pi) ** 2 / num_samples**2
     ).sum()
-    assert torch.allclose(integral_torus, torch.tensor(1.0), rtol=1e-2)
\ No newline at end of file
+    assert torch.allclose(integral_torus, torch.tensor(1.0), rtol=1e-2)
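
The docstring guidance from PATCH 1/2 can be exercised end to end. Below is a minimal, hypothetical usage sketch (not part of either patch): it places a Beta(5, 5) prior, rescaled to [-1, 1], on `weighted_correlation` so that NUTS stays away from the near-bimodal regime. The rescaling is written explicitly rather than via the `TransformedDistribution(Beta(5, 5), AffineTransform(loc=-1, scale=2))` form quoted in the docstring, and the model name, site names, data, and the priors on the locations and concentrations are illustrative assumptions:

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.infer import MCMC, NUTS


    def model(angles):
        # angles: tensor of shape (N, 2) with paired torsion angles in [-pi, pi].
        phi_loc = pyro.sample("phi_loc", dist.Uniform(-torch.pi, torch.pi))
        psi_loc = pyro.sample("psi_loc", dist.Uniform(-torch.pi, torch.pi))
        phi_conc = pyro.sample("phi_conc", dist.HalfNormal(10.0))
        psi_conc = pyro.sample("psi_conc", dist.HalfNormal(10.0))
        # Beta(5, 5) rescaled from [0, 1] to [-1, 1]; prior mass concentrates away
        # from |weighted_corr| ~ 1, the regime where the sine distribution turns bimodal.
        corr01 = pyro.sample("corr01", dist.Beta(5.0, 5.0))
        weighted_corr = 2.0 * corr01 - 1.0
        with pyro.plate("data", angles.shape[0]):
            pyro.sample(
                "obs",
                dist.SineBivariateVonMises(
                    phi_loc=phi_loc,
                    psi_loc=psi_loc,
                    phi_concentration=phi_conc,
                    psi_concentration=psi_conc,
                    weighted_correlation=weighted_corr,
                ),
                obs=angles,
            )


    # Placeholder synthetic data (independent angles, for shapes only).
    phi = dist.VonMises(0.0, 4.0).sample((200,))
    psi = dist.VonMises(1.0, 4.0).sample((200,))
    MCMC(NUTS(model), num_samples=300, warmup_steps=300).run(torch.stack([phi, psi], dim=-1))

Because `weighted_correlation` sets the correlation to `weighted_corr * sqrt(phi_conc * psi_conc)`, any value in [-1, 1] keeps the ratio in the docstring's note below one, regardless of how large the sampled concentrations get.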
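The `norm_const` change in PATCH 1/2 scales the number of series terms with the concentrations because the terms of the modified Bessel series peak near j ~ kappa / 2, so a fixed cutoff of 51 terms truncates the sum badly once the concentrations are large; using at least max(501, kappa_phi + kappa_psi) terms keeps the log-normalizer accurate, which is what the new normalization test checks. The snippet below is a standalone illustration using the series for log I_0(kappa); it deliberately avoids Pyro's internal `log_I1` helper, and the specific term counts are illustrative:

    import torch


    def log_i0_series(kappa, num_terms):
        # Truncated series for the modified Bessel function of the first kind:
        # log I_0(kappa) = logsumexp_j [2 * j * log(kappa / 2) - 2 * lgamma(j + 1)].
        j = torch.arange(num_terms, dtype=torch.float64)
        half_log = torch.log(torch.tensor(kappa / 2.0, dtype=torch.float64))
        return torch.logsumexp(2 * j * half_log - 2 * torch.lgamma(j + 1), dim=0)


    kappa = 1000.0  # "high concentration", comparable to the new test cases
    # Reference via the exponentially scaled Bessel function: I_0(k) = exp(k) * i0e(k).
    reference = kappa + torch.log(torch.special.i0e(torch.tensor(kappa, dtype=torch.float64)))

    print(log_i0_series(kappa, 51))    # badly truncated: the dominant terms sit near j ~ kappa / 2
    print(log_i0_series(kappa, 2001))  # ~2 * kappa terms, mirroring the patched heuristic
    print(reference)                   # ~995.6; the 2001-term sum agrees closely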