Merge branch 'main' into val-dataloader
ordabayevy committed May 7, 2024
2 parents e388fe6 + 4e73c9e commit b4d4202
Showing 8 changed files with 184 additions and 19 deletions.
3 changes: 3 additions & 0 deletions cellarium/ml/distributed/__init__.py
@@ -1,3 +1,6 @@
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from cellarium.ml.distributed.gather import GatherLayer

__all__ = [
8 changes: 4 additions & 4 deletions cellarium/ml/distributed/gather.py
@@ -15,7 +15,7 @@ def forward(ctx, input: torch.Tensor) -> tuple[torch.Tensor, ...]: # type: ignore[override]
return tuple(output)

@staticmethod
-    def backward(ctx, *grads) -> torch.Tensor:
-        grad_out = grads[dist.get_rank()].contiguous()
-        dist.all_reduce(grad_out, op=dist.ReduceOp.SUM)
-        return grad_out
+    def backward(ctx, *grads: torch.Tensor) -> torch.Tensor:
+        all_grads = torch.stack(grads)
+        dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
+        return all_grads[dist.get_rank()]
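
Note (not part of the commit): the sketch below simulates the old and new backward on two ranks without torch.distributed, using the same coeff matrix as the updated test further down. The new implementation stacks every incoming gradient, all-reduces the stack, and then indexes by rank, so each rank receives the gradient of the total loss with respect to its own input; the old implementation all-reduced only the local slot, which agrees only when the per-rank losses happen to be symmetric.

import torch

# Two simulated ranks; coeff mirrors the updated test in tests/distributed/test_gather.py.
coeff = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
world_size = 2

# Incoming gradients on rank r with respect to the gathered tuple: d loss_r / d w_s = coeff[r, s].
incoming = {rank: tuple(coeff[rank]) for rank in range(world_size)}

# New backward: stack all slots, sum across ranks (what all_reduce(SUM) does), index the own slot.
stacked = {rank: torch.stack(incoming[rank]) for rank in range(world_size)}
reduced = sum(stacked.values())  # identical on every rank after the all_reduce
new_grad = {rank: reduced[rank] for rank in range(world_size)}

# Old backward: keep only the local slot, then sum across ranks.
old_reduced = sum(incoming[rank][rank] for rank in range(world_size))
old_grad = {rank: old_reduced for rank in range(world_size)}

print(new_grad)  # {0: tensor(4.), 1: tensor(6.)} == coeff[:, rank].sum()
print(old_grad)  # {0: tensor(5.), 1: tensor(5.)} -- wrong once the losses differ per rank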
6 changes: 6 additions & 0 deletions cellarium/ml/distributions/__init__.py
@@ -0,0 +1,6 @@
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from cellarium.ml.distributions.negative_binomial import NegativeBinomial

__all__ = ["NegativeBinomial"]
81 changes: 81 additions & 0 deletions cellarium/ml/distributions/negative_binomial.py
@@ -0,0 +1,81 @@
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from numbers import Number

import torch
from pyro.distributions import TorchDistribution, constraints
from torch.distributions.utils import broadcast_all, lazy_property


class NegativeBinomial(TorchDistribution):
"""Negative binomial distribution.
Args:
mu:
Mean of the distribution.
theta:
Inverse dispersion.
"""

THETA_THRESHOLD_STIRLING_SWITCH = 200

arg_constraints = {"mu": constraints.greater_than_eq(0), "theta": constraints.greater_than_eq(0)}
support = constraints.nonnegative_integer

def __init__(self, mu: torch.Tensor, theta: torch.Tensor, validate_args: bool | None = None) -> None:
self.mu, self.theta = broadcast_all(mu, theta)
if isinstance(mu, Number) and isinstance(theta, Number):
batch_shape = torch.Size()
else:
batch_shape = self.mu.size()
super().__init__(batch_shape, validate_args=validate_args)

@property
def mean(self) -> torch.Tensor:
return self.mu

@property
def variance(self) -> torch.Tensor:
return (self.mu + (self.mu**2) / self.theta).masked_fill(self.theta == 0, 0)

@lazy_property
def _gamma(self) -> torch.distributions.Gamma:
# Note we avoid validating because self.theta can be zero.
return torch.distributions.Gamma(
concentration=self.theta,
rate=(self.theta / self.mu).masked_fill(self.theta == 0, 1.0),
validate_args=False,
)

def sample(self, sample_shape=torch.Size()) -> torch.Tensor:
with torch.no_grad():
rate = self._gamma.sample(sample_shape=sample_shape)
return torch.poisson(rate)

def log_prob(self, value: torch.Tensor) -> torch.Tensor: # type: ignore[override]
if self._validate_args:
self._validate_sample(value)

# Original implementation from scVI:
#
# log_theta_mu_eps = torch.log(self.theta + self.mu + self.eps)
# return (
# self.theta * (torch.log(self.theta + self.eps) - log_theta_mu_eps)
# + value * (torch.log(self.mu + self.eps) - log_theta_mu_eps)
# + torch.lgamma(value + self.theta)
# - torch.lgamma(self.theta)
# - torch.lgamma(value + 1)
# )
delta = torch.where(
self.theta > self.THETA_THRESHOLD_STIRLING_SWITCH,
(value + self.theta - 0.5) * torch.log1p(value / self.theta) - value,
(value + self.theta).lgamma() - self.theta.lgamma() - torch.xlogy(value, self.theta),
)
# The case self.theta == 0 and value == 0 has probability 1.
# The case self.theta == 0 and value != 0 has probability 0.
return (
(delta - (value + self.theta) * torch.log1p(self.mu / self.theta)).masked_fill(self.theta == 0, 0)
- (value + 1).lgamma()
+ torch.xlogy(value, self.mu)
)
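
Note (not part of the commit): a minimal usage sketch, assuming cellarium-ml and pyro-ppl are installed. log_prob evaluates the standard negative-binomial log-pmf in the (mu, theta) parameterization, log p(k) = lgamma(k + theta) - lgamma(theta) - lgamma(k + 1) + theta * log(theta / (theta + mu)) + k * log(mu / (theta + mu)), switching to a Stirling-style approximation for large theta and handling the theta == 0 edge case per the in-code comments (probability 1 at zero, probability 0 elsewhere). Pyro's (total_count, logits) parameterization is related by total_count = theta and logits = log(mu / theta), the same conversion used in the new test.

import pyro.distributions as dist
import torch

from cellarium.ml.distributions import NegativeBinomial

mu = torch.tensor([0.5, 3.0, 20.0])    # mean
theta = torch.tensor([2.0, 2.0, 2.0])  # inverse dispersion
nb = NegativeBinomial(mu, theta)

value = torch.tensor([0.0, 2.0, 15.0])
print(nb.mean, nb.variance)  # mu and mu + mu**2 / theta
print(nb.log_prob(value))

# The same distribution in Pyro's (total_count, logits) parameterization.
pyro_nb = dist.NegativeBinomial(theta, logits=(mu / theta).log())
print(pyro_nb.log_prob(value))  # agrees up to numerical precision

# Sampling draws a Gamma rate (the _gamma property) and then a Poisson count.
samples = nb.sample(torch.Size([10_000]))
print(samples.mean(0))  # close to mu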
7 changes: 7 additions & 0 deletions docs/source/distributions.rst
@@ -0,0 +1,7 @@
Distributions
=============

.. automodule:: cellarium.ml.distributions
:members:
:special-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -15,6 +15,7 @@ Table of Contents
core
data
distributed
distributions
lr_schedulers
models
transforms
33 changes: 18 additions & 15 deletions tests/distributed/test_gather.py
@@ -1,3 +1,6 @@
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

import os

import torch
@@ -23,14 +26,11 @@ def run(rank: int, world_size: int, return_dict: dict) -> None:
# | / \ |
# loss_0 loss_1

-    # loss_0 = w_0 + w_1
-    # loss_1 = w_0 + w_1
-
-    # dloss/dw_0 = dloss_0/dw_0 + dloss_1/dw_0 = 2
-    # dloss/dw_1 = dloss_0/dw_1 + dloss_1/dw_1 = 2
-    w = torch.ones(1, requires_grad=True)  # w_rank
-    gathered_w = GatherLayer.apply(w)  # (w_0, w_1)
-    loss = gathered_w[0] + gathered_w[1]  # w_0 + w_1
+    # loss_rank = coeff[rank, 0] * w[0] + coeff[rank, 1] * w[1]
+    coeff = torch.tensor([[1, 2], [3, 4]])
+    w_rank = torch.ones(1, requires_grad=True)  # w_rank
+    w = GatherLayer.apply(w_rank)  # (w_0, w_1)
+    loss = coeff[rank, 0] * w[0] + coeff[rank, 1] * w[1]
loss.backward()
return_dict[rank] = w.grad

@@ -49,12 +49,15 @@ def test_gather_layer():
p.join()

# Single GPU
-    # loss = w + w
-    # dloss/dw = 2
-    w = torch.ones(1, requires_grad=True)
-    gathered_w = (w, w)
-    loss = gathered_w[0] + gathered_w[1]  # w + w
+    # loss = coeff[0, 0] * w[0] + coeff[0, 1] * w[1] + coeff[1, 0] * w[0] + coeff[1, 1] * w[1]
+    #        ---------------- rank 0 ---------------   ---------------- rank 1 ---------------
+    # dloss/dw[0] = coeff[0, 0] + coeff[1, 0]
+    # dloss/dw[1] = coeff[0, 1] + coeff[1, 1]
+    coeff = torch.tensor([[1, 2], [3, 4]])
+    w = torch.tensor([1.0, 1.0], requires_grad=True)
+    loss = (coeff[0] * w).sum() + (coeff[1] * w).sum()
loss.backward()

-    for w_grad in return_dict.values():
-        assert w_grad == w.grad
+    for rank, w_grad in return_dict.items():
+        assert w.grad is not None
+        assert w_grad == w.grad[rank] == coeff[:, rank].sum()
64 changes: 64 additions & 0 deletions tests/distributions/test_negative_binomial.py
@@ -0,0 +1,64 @@
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

import numpy as np
import pyro.distributions as dist
import pytest
import torch

from cellarium.ml.distributions import NegativeBinomial


@pytest.mark.parametrize("logits_shape", [(), (2,), (3, 2)])
@pytest.mark.parametrize("total_counts_shape", [(), (2,), (3, 2)])
def test_negative_binomial(logits_shape: torch.Size, total_counts_shape: torch.Size) -> None:
logits = torch.randn(logits_shape)
total_counts = torch.rand(total_counts_shape) * 10

if len(total_counts_shape) == 2:
total_counts[0] = 0

pyro_dist = dist.NegativeBinomial(total_counts, logits=logits) # type: ignore[attr-defined]

mu = torch.exp(logits) * total_counts
theta = total_counts
cellarium_nb = NegativeBinomial(mu, theta)

# shape
assert cellarium_nb.batch_shape == pyro_dist.batch_shape

# mean
np.testing.assert_allclose(cellarium_nb.mean, pyro_dist.mean, rtol=1e-5)

# variance
np.testing.assert_allclose(cellarium_nb.variance, pyro_dist.variance, rtol=1e-5)

# log_prob
value = torch.randint(20, size=(3, 2))
if len(total_counts_shape) == 2:
value[0, 0] = 0
value[0, 1] = 2.0
pyro_log_prob = pyro_dist.log_prob(value)
cellarium_log_prob = cellarium_nb.log_prob(value)
np.testing.assert_allclose(pyro_log_prob, cellarium_log_prob, rtol=1e-5)

# sample
samples = cellarium_nb.sample(torch.Size([50_000]))

expected_mean = cellarium_nb.mean
actual_mean = samples.mean(0)
np.testing.assert_allclose(actual_mean, expected_mean, atol=0.02, rtol=0.05)

expected_var = cellarium_nb.variance
actual_var = samples.var(0)
np.testing.assert_allclose(actual_var, expected_var, atol=0.02, rtol=0.05)


@pytest.mark.parametrize("mu", torch.logspace(-4, 3, 8))
@pytest.mark.parametrize("theta", torch.logspace(-2, 6, 9))
def test_total_probability(mu: torch.Tensor, theta: torch.Tensor) -> None:
values = torch.arange(0, 2 + int(mu * 1e3))
log_probs = NegativeBinomial(mu, theta).log_prob(values)
expected = torch.tensor(0.0)
actual = log_probs.logsumexp(0)
assert torch.allclose(actual, expected, atol=5e-4)
