Skip to content

Commit

Permalink
Add bounds parameter to geom_density
Browse files Browse the repository at this point in the history
closes #796
  • Loading branch information
has2k1 committed Jun 22, 2024
1 parent 552c3cb commit 9302df8
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 8 deletions.
4 changes: 4 additions & 0 deletions doc/changelog.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ title: Changelog
- [](:class:`~plotnine.geom_text`) has gained new aesthetics
`fontvariant` and `fontstretch`.

- [](:stat:`~plotnine.stat_density`) has gained a new parameter `bounds`
that you can use remove asymptotic boundary effects that arise from
density estimates on an infinite domain. ({{< issue 796 >}})

### Bug Fixes

- Fix layers 3 and above not to overlap the axis lines if there are any
Expand Down
85 changes: 78 additions & 7 deletions plotnine/stats/stat_density.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import typing
from contextlib import suppress
from typing import TYPE_CHECKING, cast
from warnings import warn

import numpy as np
Expand All @@ -12,8 +12,8 @@
from ..mapping.evaluation import after_stat
from .stat import stat

if typing.TYPE_CHECKING:
from plotnine.typing import FloatArrayLike
if TYPE_CHECKING:
from plotnine.typing import FloatArray, FloatArrayLike


# NOTE: Parameter descriptions are in
Expand Down Expand Up @@ -77,6 +77,11 @@ class stat_density(stat):
clip : tuple[float, float], default=(-inf, inf)
Values in `x` that are outside of the range given by clip are
dropped. The number of values in `x` is then shortened.
bounds: tuple[float, float], default=(-inf, inf)
The domain boundaries of the data. When the domain is finite the
estimated density will be corrected to remove asymptotic boundary
effects that are usually biased away from the probability density
function being estimated.
See Also
--------
Expand Down Expand Up @@ -115,6 +120,7 @@ class stat_density(stat):
"bw": "nrd0",
"cut": 3,
"clip": (-np.inf, np.inf),
"bounds": (-np.inf, np.inf),
}
DEFAULT_AES = {"y": after_stat("density")}
CREATES = {"density", "count", "scaled", "n"}
Expand Down Expand Up @@ -165,12 +171,12 @@ def compute_density(x, weight, range, **params):
x = np.asarray(x, dtype=float)
not_nan = ~np.isnan(x)
x = x[not_nan]
bw = params["bw"]
bw = cast(str | float, params["bw"])
kernel = params["kernel"]
bounds = params["bounds"]
has_bounds = not (np.isneginf(bounds[0]) and np.isposinf(bounds[1]))
n = len(x)

assert isinstance(bw, (str, float)) # type narrowing

if n == 0 or (n == 1 and isinstance(bw, str)):
if n == 1:
warn(
Expand Down Expand Up @@ -211,7 +217,19 @@ def compute_density(x, weight, range, **params):
clip=params["clip"],
)

x2 = np.linspace(range[0], range[1], params["n"])
if has_bounds:
# kde.support is the grid over which the kernel function is
# defined and the first and last values of this grid are:
#
# [min(x)-cut*bw, max(x)+cut*bw]
#
# i.e. the grid is wider than the ptp range of x.
# Evaluating values beyond the ptp range helps us calculate a
# boundary corrections. So we widen the range over which we will
# evaluate, so that it contains all points supported by the grid.
x2 = np.linspace(kde.support[0], kde.support[-1], params["n"])
else:
x2 = np.linspace(range[0], range[1], params["n"])

try:
y = kde.evaluate(x2)
Expand All @@ -235,6 +253,10 @@ def compute_density(x, weight, range, **params):
not_nan = ~np.isnan(y)
x2 = x2[not_nan]
y = y[not_nan]

if has_bounds:
x2, y = fit_density_to_bounds(x2, y, range, bounds)

return pd.DataFrame(
{
"x": x2,
Expand Down Expand Up @@ -277,3 +299,52 @@ def nrd0(x: FloatArrayLike) -> float:
if low_std == 0:
low_std = std_estimate or np.abs(np.asarray(x)[0]) or 1
return 0.9 * low_std * (n**-0.2)


def fit_density_to_bounds(
x: FloatArray,
y: FloatArray,
range: tuple[float, float],
bounds: tuple[float, float],
) -> tuple[FloatArray, FloatArray]:
"""
Fit calculated density to the given bounds
Parameters
----------
x :
Points at which the density is estimated. `x` is expected to
to include all values of the density grid.
y :
Estimated density.
range :
bounds :
Valid boundary (domain) of the x values.
Returns
-------
x_bound :
Points that fall within the bounds at which the density is
estimated.
y_bound :
Estimated densities at points within the bounds.
"""

def interpolate(x2: FloatArray) -> FloatArray:
# Interpolate (linearly) along the density function
# The values at points beyond (left or right) the original
# grid (x) are zero.
return np.interp(x2, x, y, left=0, right=0)

# The boundary corrections work by:
# 1. reflecting values outside the bounds so that they fall within
# the bounds to give a correction values
# 2. adding the correction values to the original density
new_range = max(range[0], bounds[0]), min(range[1], bounds[1])
x_bound = np.linspace(new_range[0], new_range[1], len(x))
y_bound = (
interpolate(x_bound)
+ interpolate(2 * bounds[0] - x_bound)
+ interpolate(2 * bounds[1] - x_bound)
)
return x_bound, y_bound
1 change: 1 addition & 0 deletions plotnine/stats/stat_sina.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def setup_params(self, data):
params["cut"] = 0
params["gridsize"] = None
params["clip"] = (-np.inf, np.inf)
params["bounds"] = (-np.inf, np.inf)
params["n"] = 512
return params

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 16 additions & 1 deletion tests/test_geom_density.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import numpy as np
import pandas as pd
import pytest
import scipy.stats as stats

from plotnine import aes, geom_density, ggplot, lims
from plotnine import aes, geom_density, ggplot, lims, stat_function
from plotnine.exceptions import PlotnineWarning

n = 6 # Some even number greater than 2
Expand Down Expand Up @@ -58,3 +59,17 @@ def test_few_datapoints():
+ lims(x=(0, 4))
)
assert p == "few_datapoints"


def test_bounds():
rs = np.random.RandomState(123)
data = pd.DataFrame({"x": rs.uniform(size=1000)})

p = (
ggplot(data, aes("x"))
+ geom_density()
+ geom_density(bounds=(0, 1), color="blue")
+ stat_function(fun=stats.uniform.pdf, color="red")
)

assert p == "bounds"

0 comments on commit 9302df8

Please sign in to comment.