Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files: browse the repository at this point in the history
  • Loading branch information
Y-nuclear committed Dec 11, 2023
2 parents ebf3137 + 78149a6 commit ef14f3a
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/gnnwr/datasets.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -475,7 +475,7 @@ def init_dataset(data, test_ratio, valid_ratio, x_column, y_column, spatial_colu
id_column=None, sample_seed=42, process_fn="minmax_scale", batch_size=32, shuffle=True,
use_class=baseDataset,
spatial_fun=BasicDistance, temporal_fun=Manhattan_distance, max_val_size=-1, max_test_size=-1,
from_for_cv=0, is_need_STNN=False, Reference=None, simple_distance=True):
from_for_cv=0, is_need_STNN=False, Reference=None, simple_distance=True, dropna=True):
"""
Initialize the dataset and return the training set, validation set and test set for the model
Expand Down Expand Up @@ -511,6 +511,11 @@ def init_dataset(data, test_ratio, valid_ratio, x_column, y_column, spatial_colu
# if dist_column is None, raise error
raise ValueError(
"dist_column must be a column name in data")
if dropna:
oriLen = data.shape[0]
data.dropna(axis=0,how='any',inplace=True)
if oriLen > data.shape[0]:
warnings.warn("Dropping {} {} with missing values. To forbid dropping, you need to set the argument dropna=False".format(oriLen - data.shape[0],'row' if oriLen - data.shape[0] == 1 else 'rows'))
if id_column is None:
id_column = ['id']
if 'id' not in data.columns:
Expand Down

0 comments on commit ef14f3a

Please sign in to comment.