Merge pull request #21 from srivarra/lowess/groupby
Added Groupby to Lowess
Ofosu-Osei authored Oct 19, 2024
2 parents a657817 + 112a006 commit 0778b40
Showing 4 changed files with 28 additions and 32 deletions.
5 changes: 2 additions & 3 deletions README.md
@@ -201,18 +201,17 @@ import seaborn.objects as so
import seaborn as sns
import seaborn_objects_recipes as sor


def test_lowess_with_ci():

# Load the penguins dataset
penguins = sns.load_dataset("penguins")

# Prepare data
data = penguins.copy()
data = penguins[penguins['species'] == 'Adelie']

# Create the plot
plot = (
so.Plot(data, x="bill_length_mm", y="body_mass_g")
so.Plot(data, x="bill_length_mm", y="body_mass_g", color="species")
.add(so.Dot())
.add(so.Line(), lowess := sor.Lowess(frac=0.2, gridsize=100, num_bootstrap=200, alpha=0.95))
.add(so.Band(), lowess)
Binary file modified img/lowess_b.png
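For reference, a minimal, self-contained sketch of the usage this change enables, mirroring the updated README and test code in this commit: with a color mapping in place, the stat now fits one LOWESS curve and one bootstrap confidence band per species. The final show/save call is an assumption, since the README hunk above is truncated.

import seaborn as sns
import seaborn.objects as so
import seaborn_objects_recipes as sor

# Load the penguins dataset and let the color mapping define the groups
penguins = sns.load_dataset("penguins")

lowess = sor.Lowess(frac=0.2, gridsize=100, num_bootstrap=200, alpha=0.95)
plot = (
    so.Plot(penguins, x="bill_length_mm", y="body_mass_g", color="species")
    .add(so.Dot())
    .add(so.Line(), lowess)   # smoothed line, fit per group
    .add(so.Band(), lowess)   # bootstrap confidence band, computed per group
)
plot.show()  # assumed; the original example may save to img/lowess_b.png instead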
50 changes: 25 additions & 25 deletions seaborn_objects_recipes/recipes/lowess.py
@@ -1,12 +1,12 @@
from __future__ import annotations
import numpy as np
import pandas as pd
from pandas import DataFrame
from dataclasses import dataclass
from seaborn._stats.base import Stat
import statsmodels.api as sm
from typing import Optional


@dataclass
class Lowess(Stat):
"""
@@ -39,34 +39,29 @@ class Lowess(Stat):
delta: float = 0.0
num_bootstrap: Optional[int] = None
alpha: float = 0.95


def __post_init__(self):
# Type checking for the arguments
if not isinstance(self.frac, float) or not (0 < self.frac <= 1):
raise ValueError("frac must be a float between 0 and 1.")
if not isinstance(self.gridsize, int) or self.gridsize <= 0:
raise ValueError("gridsize must be a positive integer.")
if self.num_bootstrap is not None and (
not isinstance(self.num_bootstrap, int) or self.num_bootstrap <= 0
):
if self.num_bootstrap is not None and (not isinstance(self.num_bootstrap, int) or self.num_bootstrap <= 0):
raise ValueError("num_bootstrap must be a positive integer or None.")
if not isinstance(self.alpha, float) or not (0 < self.alpha < 1):
raise ValueError("alpha must be a float between 0 and 1.")

def _fit_predict(self, data):
x = data["x"]
xx = np.linspace(x.min(), x.max(), self.gridsize)
result = sm.nonparametric.lowess(
endog=data["y"], exog=x, frac=self.frac, delta=self.delta, xvals=xx
)
result = sm.nonparametric.lowess(endog=data["y"], exog=x, frac=self.frac, delta=self.delta, xvals=xx)
if result.ndim == 1: # Handle single-dimensional return values
yy = result
else:
yy = result[:, 1] # Select the predicted y-values
return pd.DataFrame(dict(x=xx, y=yy))

def _bootstrap_resampling(self, data):
def _bootstrap_resampling(self, data) -> pd.DataFrame:
xx = np.linspace(data["x"].min(), data["x"].max(), self.gridsize)
bootstrap_estimates = np.empty((self.num_bootstrap, len(xx)))

@@ -81,34 +76,39 @@ def _bootstrap_resampling(self, data):
)
# Ensure the result is two-dimensional
if result.ndim == 1:
result = np.column_stack(
(xx, result)
) # Reformat to two-dimensional if needed
result = np.column_stack((xx, result)) # Reformat to two-dimensional if needed
bootstrap_estimates[i, :] = result[:, 1]

return xx, bootstrap_estimates
lower_bound = np.percentile(bootstrap_estimates, (1 - self.alpha) / 2 * 100, axis=0)
upper_bound = np.percentile(bootstrap_estimates, (1 + self.alpha) / 2 * 100, axis=0)

return pd.DataFrame({"ymin": lower_bound, "ymax": upper_bound})

def __call__(self, data: DataFrame, groupby, orient, scales) -> DataFrame:
def __call__(self, data: pd.DataFrame, groupby, orient, scales) -> pd.DataFrame:
if orient == "x":
xvar = data.columns[0]
yvar = data.columns[1]
else:
xvar = data.columns[1]
yvar = data.columns[0]

renamed_data = data.rename(columns={xvar: "x", yvar: "y"})
renamed_data = data.rename(columns={xvar: "x", yvar: "y"})
renamed_data = renamed_data.dropna(subset=["x", "y"])
smoothed = self._fit_predict(renamed_data)

grouping_vars = [str(v) for v in data if v in groupby.order]

if not grouping_vars:
# If no grouping variables, directly fit and predict
smoothed = self._fit_predict(renamed_data)
else:
# Apply the fit_predict method for each group separately
smoothed = groupby.apply(renamed_data, self._fit_predict)

if self.num_bootstrap:
xx, bootstrap_estimates = self._bootstrap_resampling(data)
lower_bound = np.percentile(
bootstrap_estimates, (1 - self.alpha) / 2 * 100, axis=0
)
upper_bound = np.percentile(
bootstrap_estimates, (1 + self.alpha) / 2 * 100, axis=0
)
smoothed["ymin"] = lower_bound
smoothed["ymax"] = upper_bound
if not grouping_vars:
bootstrap_estimates = self._bootstrap_resampling(data)
else:
bootstrap_estimates = groupby.apply(data, self._bootstrap_resampling)

return smoothed
return smoothed.join(bootstrap_estimates[["ymin", "ymax"]]) if self.num_bootstrap else smoothed
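To make the confidence-band arithmetic in _bootstrap_resampling explicit, here is a minimal sketch with toy data (not the estimator itself), assuming only NumPy: for alpha = 0.95 the band spans the 2.5th to 97.5th percentile of the bootstrap estimates at each grid point.

import numpy as np

alpha = 0.95
rng = np.random.default_rng(0)
# Toy stand-in for the (num_bootstrap, gridsize) array filled in above
bootstrap_estimates = rng.normal(size=(200, 100))

lower_bound = np.percentile(bootstrap_estimates, (1 - alpha) / 2 * 100, axis=0)  # 2.5th percentile
upper_bound = np.percentile(bootstrap_estimates, (1 + alpha) / 2 * 100, axis=0)  # 97.5th percentile
assert lower_bound.shape == upper_bound.shape == (100,)  # one bound per grid point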
5 changes: 1 addition & 4 deletions test_main.py
@@ -113,12 +113,9 @@ def test_lowess_with_ci(cleanup_files):
# Load the penguins dataset
penguins = sns.load_dataset("penguins")

# Prepare data
data = penguins[penguins['species'] == 'Adelie']

# Create the plot
plot = (
so.Plot(data, x="bill_length_mm", y="body_mass_g")
so.Plot(penguins, x="bill_length_mm", y="body_mass_g", color="species")
.add(so.Dot())
.add(so.Line(), lowess := sor.Lowess(frac=0.2, gridsize=100, num_bootstrap=200, alpha=0.95))
.add(so.Band(), lowess)
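Assuming pytest is the project's test runner (the cleanup_files fixture above suggests so), the updated example can be exercised directly:

pytest test_main.py::test_lowess_with_ci -q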
