diff --git a/sdv/_utils.py b/sdv/_utils.py index 1a864a422..802f031ec 100644 --- a/sdv/_utils.py +++ b/sdv/_utils.py @@ -453,14 +453,20 @@ def _is_numerical(value): def _prepare_data_vizualisation(data, metadata, column_names, sample_size): - """Prepare the data for a column pair plot. + """Prepare the data for a plot. Args: - data (pd.DataFrame): + data (pd.DataFrame or None): The data to be prepared. + metadata (Metadata): + The metadata of the data. + column_names (str or list[str]): + The column names to plot. + sample_size (int or None): + The number of samples to plot. If ``None``, use the whole dataset. Returns: - pd.DataFrame: + pd.DataFrame or None: The prepared data. """ if data is None: