Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jitingxu1 committed Apr 5, 2024
1 parent eb12a08 commit 0f47259
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 12 deletions.
11 changes: 4 additions & 7 deletions ibisml/steps/feature_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class ZeroVariance(Step):
To remove all numeric columns with zero variance:
>>> step = ml.ZeroVariance(ml.numeric())
To remove all non-numeric columns with only one unique value:
>>> step = ml.ZeroVariance(ml.string())
To remove all string or categorical columns with only one unique value:
>>> step = ml.ZeroVariance(ml.norminal())
"""

def __init__(self, inputs: SelectionType, *, tolerance: int | float = 1e-4):
Expand Down Expand Up @@ -64,14 +64,11 @@ def fit_table(self, table: ir.Table, metadata: Metadata) -> None:
# Check variance for numeric columns
if results[f"{name}_var"] < self.tolerance:
cols.append(name)
elif not isinstance(c, ir.NumericColumn) and results[f"{name}_var"] < 2:
elif results[f"{name}_var"] < 2:
# Check unique count for non-numeric columns
cols.append(name)

self.cols_ = cols

def transform_table(self, table: ir.Table) -> ir.Table:
if len(self.cols_) > 0:
return table.drop(self.cols_)
else:
return table
return table.drop(self.cols_)
11 changes: 6 additions & 5 deletions tests/test_feature_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ def test_zero_variance():
start_timestamp + pd.Timedelta(minutes=i) for i in range(10)
]

non_zv_cols = {
"non_zero_variance_numeric_col",
"non_zero_variance_string_col",
"non_zero_variance_timestamp_col",
zv_cols = {
"zero_variance_numeric_col",
"zero_variance_string_col",
"zero_variance_timestamp_col",
}

t_train = ibis.memtable(
Expand All @@ -45,4 +45,5 @@ def test_zero_variance():
step = ml.ZeroVariance(ml.everything())
step.fit_table(t_train, ml.core.Metadata())
res = step.transform_table(t_test)
assert set(res.columns) == non_zv_cols
sol = t_test.drop(zv_cols)
assert sol.equals(res)

0 comments on commit 0f47259

Please sign in to comment.