Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First update #6

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ seaborn_objects_recipes is a Python package that extends the functionality of th
- [LineLabel](https://github.com/Ofosu-Osei/seaborn_objects_recipes/blob/main/seaborn_objects_recipes/recipes/line_label.py)
- [Lowess](https://github.com/Ofosu-Osei/seaborn_objects_recipes/blob/main/seaborn_objects_recipes/recipes/lowess.py)
- [PolyFitCI](https://github.com/Ofosu-Osei/seaborn_objects_recipes/blob/main/seaborn_objects_recipes/recipes/plotting.py)
- [PolyFit](https://github.com/Ofosu-Osei/seaborn_objects_recipes/blob/main/seaborn_objects_recipes/recipes/plotting.py)

## Installation

Expand Down Expand Up @@ -221,6 +222,41 @@ def test_regression_with_ci(cleanup_files):

![regwithci](img/reg_with_ci.png)

### PolyFit
```python
def test_polyfit_with_ci(cleanup_files):
    """Plot a bootstrapped quadratic fit, with its confidence band, for Adelie penguins."""
    x_col, y_col = 'bill_length_mm', 'body_mass_g'

    # Restrict the penguins dataset to a single species
    penguins = sns.load_dataset("penguins")
    adelie = penguins[penguins['species'] == 'Adelie']

    # Second-order fit evaluated on a 100-point grid, with 200 bootstrap resamples
    fitter = sor.PolyFit(order=2, gridsize=100, num_bootstrap=200, alpha=0.05)
    fit = fitter(adelie, x_col, y_col)

    # Raw observations first, fitted curve drawn on top
    fig, ax = plt.subplots(figsize=(9, 5))
    sns.scatterplot(x=x_col, y=y_col, data=adelie, ax=ax, color='blue', alpha=0.5)
    ax.plot(fit[x_col], fit[y_col], color='darkblue')

    # The band is only drawn when the fit produced both interval bounds
    if {'ci_lower', 'ci_upper'}.issubset(fit.columns):
        ax.fill_between(
            fit[x_col],
            fit['ci_lower'],
            fit['ci_upper'],
            color='blue',
            alpha=0.3,
        )

    # Axis decoration
    ax.set_xlabel('Bill Length (mm)')
    ax.set_ylabel('Body Mass (g)')
    ax.set_title('Polynomial Fit with Confidence Intervals for Adelie Penguins')
    ax.grid(True, which='both', color='gray', linewidth=0.5, linestyle='--')
    plt.show()
```
### Output

![polyfitwithci](img/polyfit_with_ci.png)


## Contact

Expand Down
Binary file added img/polyfit_with_ci.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added lowess.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 3 additions & 1 deletion seaborn_objects_recipes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,7 @@

from .recipes.plotting import PolyFitCI # noqa: F401

__all__ = ['Rolling', 'LineLabel', 'Lowess', 'PolyFitCI']
from .recipes.plotting import PolyFit # noqa: F401

__all__ = ['Rolling', 'LineLabel', 'Lowess', 'PolyFitCI','PolyFit']

62 changes: 61 additions & 1 deletion seaborn_objects_recipes/recipes/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
from matplotlib import style
from dataclasses import dataclass
import statsmodels.formula.api as smf
from seaborn._stats.base import Stat
import seaborn.objects as so
import pandas as pd
import numpy as np
from typing import Optional

@dataclass
class PolyFitCI(so.PolyFit):
Expand Down Expand Up @@ -61,4 +63,62 @@ def plot(self, data, xvar, yvar):


return plot




@dataclass
class PolyFit(Stat):
    """
    Fit a polynomial of the given order and resample data onto predicted curve
    including confidence intervals.

    Parameters
    ----------
    alpha : float
        Significance level of the bootstrap confidence band; the band spans
        the ``alpha / 2`` to ``1 - alpha / 2`` quantiles of the bootstrap
        curves (``alpha=0.05`` gives a 95% interval).
    order : int
        Degree of the fitted polynomial.
    gridsize : int
        Number of evenly spaced x positions at which the fit is evaluated.
    num_bootstrap : int or None
        Number of bootstrap resamples used for the confidence band. When
        ``None`` no band is computed and no ``ci_*`` columns are returned.
    """
    alpha: float = 0.05
    order: int = 2
    gridsize: int = 100
    num_bootstrap: Optional[int] = None

    def __post_init__(self):
        """Validate the configuration; raise ``ValueError`` on a bad value."""
        if not isinstance(self.order, int) or self.order <= 0:
            raise ValueError("order must be a positive integer.")
        if not isinstance(self.gridsize, int) or self.gridsize <= 0:
            raise ValueError("gridsize must be a positive integer.")
        if self.num_bootstrap is not None and (
            not isinstance(self.num_bootstrap, int) or self.num_bootstrap <= 0
        ):
            raise ValueError("num_bootstrap must be a positive integer or None.")
        if not isinstance(self.alpha, float) or not (0 < self.alpha < 1):
            raise ValueError("alpha must be a float between 0 and 1.")

    def _fit_predict(self, data):
        """
        Fit the polynomial to the ``x``/``y`` columns of ``data``.

        Rows with a missing ``x`` or ``y`` are dropped first. Returns a
        DataFrame with columns ``x`` and ``y`` (the fitted curve on the
        evaluation grid) plus ``ci_lower`` and ``ci_upper`` when
        ``num_bootstrap`` is set. If there are not more unique x values than
        the polynomial order, the returned frame is empty.
        """
        data = data.dropna(subset=["x", "y"])
        x = data["x"]
        y = data["y"]
        with_ci = bool(self.num_bootstrap)

        if x.nunique() <= self.order:
            # Too few distinct x values to determine the polynomial. Return
            # an empty frame with the usual columns instead of attempting
            # bootstrap fits on degenerate data.
            columns = ["x", "y"] + (["ci_lower", "ci_upper"] if with_ci else [])
            return pd.DataFrame({name: [] for name in columns})

        coefs = np.polyfit(x, y, self.order)
        xx = np.linspace(x.min(), x.max(), self.gridsize)
        results = pd.DataFrame(dict(x=xx, y=np.polyval(coefs, xx)))

        if with_ci:
            # One row per bootstrap resample, one column per grid position.
            estimates = np.empty((self.num_bootstrap, self.gridsize))
            for i in range(self.num_bootstrap):
                sample = data.sample(frac=1, replace=True)
                boot_coefs = np.polyfit(sample["x"], sample["y"], self.order)
                estimates[i, :] = np.polyval(boot_coefs, xx)

            # A (1 - alpha) interval leaves alpha / 2 of the bootstrap curves
            # in each tail, e.g. the 2.5th and 97.5th percentiles for
            # alpha = 0.05.
            tail = 100 * self.alpha / 2
            lower, upper = np.percentile(estimates, [tail, 100 - tail], axis=0)
            results["ci_lower"] = lower
            results["ci_upper"] = upper

        return results

    def __call__(self, data, xvar, yvar):
        """
        Fit ``yvar`` against ``xvar`` in ``data`` and return the fitted curve.

        The result uses the caller's column names ``xvar`` and ``yvar`` for
        the grid and the prediction, plus ``ci_lower``/``ci_upper`` when
        bootstrapping is enabled.
        """
        # Work on just the two relevant columns so that renaming them to
        # "x"/"y" cannot collide with unrelated columns of the same name.
        data_renamed = data[[xvar, yvar]].rename(columns={xvar: "x", yvar: "y"})
        results = self._fit_predict(data_renamed)
        return results.rename(columns={"x": xvar, "y": yvar})
514 changes: 514 additions & 0 deletions test.ipynb

Large diffs are not rendered by default.

38 changes: 37 additions & 1 deletion test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def cleanup_files():
os.remove("lowess_nb.png")
if os.path.exists("reg_with_ci.png"):
os.remove("reg_with_ci.png")
if os.path.exists("polyfit_with_ci.png"):
os.remove("polyfit_with_ci.png")


# Use the sample_data fixture to provide data to the test function
Expand Down Expand Up @@ -211,4 +213,38 @@ def test_regression_with_ci(cleanup_files):
#plt.show()

# Assert that the file was created
assert os.path.exists("reg_with_ci.png"), "The plot file lowess.png was not created."
assert os.path.exists("reg_with_ci.png"), "The plot file lowess.png was not created."


def test_polyfit_with_ci(cleanup_files):
    """PolyFit with bootstrapping returns CI columns and the plot can be saved."""
    # Load the penguins dataset
    penguins = sns.load_dataset("penguins")

    # Prepare data
    data = penguins[penguins['species'] == 'Adelie']

    # Initialize PolyFit instance with bootstrapping
    poly_fit_with_bootstrap = sor.PolyFit(order=2, gridsize=100, num_bootstrap=200, alpha=0.05)

    # Call the PolyFit method on prepared data
    results_with_bootstrap = poly_fit_with_bootstrap(data, 'bill_length_mm', 'body_mass_g')

    # Bootstrapping was requested, so the interval bounds must be present;
    # fail here rather than silently drawing a plot without the band.
    assert 'ci_lower' in results_with_bootstrap.columns, "ci_lower column is missing."
    assert 'ci_upper' in results_with_bootstrap.columns, "ci_upper column is missing."

    # Plotting
    fig, ax = plt.subplots(figsize=(9, 5))
    sns.scatterplot(x='bill_length_mm', y='body_mass_g', data=data, ax=ax, color='blue', alpha=0.5)
    ax.plot(results_with_bootstrap['bill_length_mm'], results_with_bootstrap['body_mass_g'], color='darkblue')
    ax.fill_between(
        results_with_bootstrap['bill_length_mm'],
        results_with_bootstrap['ci_lower'],
        results_with_bootstrap['ci_upper'],
        color='blue',
        alpha=0.3,
    )
    ax.set_xlabel('Bill Length (mm)')
    ax.set_ylabel('Body Mass (g)')
    ax.set_title('Polynomial Fit with Confidence Intervals for Adelie Penguins')
    ax.grid(True, which='both', color='gray', linewidth=0.5, linestyle='--')
    plt.savefig("polyfit_with_ci.png")
    # Release the figure so it does not accumulate across the test session
    plt.close(fig)

    # Assert that the file was created
    assert os.path.exists("polyfit_with_ci.png"), "The plot file polyfit_with_ci.png was not created."
Loading