Skip to content

Commit

Permalink
Added title, colors, and updated how multiple sequences are handled
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 committed Nov 8, 2023
1 parent d267933 commit 39fec6e
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 33 deletions.
104 changes: 86 additions & 18 deletions sdmetrics/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff
import plotly.graph_objects as go
from pandas.api.types import is_datetime64_dtype

from sdmetrics.reports.utils import PlotConfig
Expand Down Expand Up @@ -480,12 +481,14 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None
return _generate_box_plot(all_data, columns)


def _generate_line_plot(all_data, x_axis, y_axis, marker, annotations=None):
def _generate_line_plot(real_data, synthetic_data, x_axis, y_axis, marker, annotations=None):
"""Generate a line plot of the real and synthetic data separated by a marker column.
Args:
all_data (pandas.DataFrame):
The combined data (real and synthetic) used for the graph.
real_data (pandas.DataFrame):
The real table data.
synthetic_column (pandas.Dataframe):
The synthetic table data.
x_axis (str):
The column name to be used as the x-axis of the graph
y_axis (str):
Expand All @@ -499,14 +502,68 @@ def _generate_line_plot(all_data, x_axis, y_axis, marker, annotations=None):
plotly.graph_objects._figure.Figure
"""
# Check if the column is the appropriate type
all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)
if not (is_datetime(all_data[x_axis]) or
pd.api.types.is_numeric_dtype(all_data[x_axis])):
raise ValueError(
f"Sequence Index '{x_axis}' must contain numerical or datetime values only")

fig = px.line(all_data, x=x_axis, y=y_axis, color=marker, markers=True)
fig = px.line(all_data, x=x_axis, y=y_axis, color=marker, markers=True, color_discrete_map={
'Real': PlotConfig.DATACEBO_DARK, 'Synthetic': PlotConfig.DATACEBO_GREEN
})
if annotations:
fig.add_annotation(annotations)

fig.update_layout(
title_text=f"Real vs Synthetic Data for column: '{y_axis}'",
font={'size': PlotConfig.FONT_SIZE},
)

# Add min-max shading
if 'min' in all_data and 'max' in all_data:
fig.add_trace(
go.Scatter(
name='Real-Min',
x=real_data[x_axis],
y=real_data['min'],
marker={'color': 'rgba(0, 0, 54, 0.25)'},
showlegend=False,
mode='lines'
)
)
fig.add_trace(
go.Scatter(
name='Real-Max',
x=real_data[x_axis],
y=real_data['max'],
marker={'color': 'rgba(0, 0, 54, 0.25)'},
showlegend=False,
mode='lines',
fill='tonexty',
fillcolor='rgba(0, 0, 54, 0.25)',
)
)
fig.add_trace(
go.Scatter(
name='Synthetic-Min',
x=synthetic_data[x_axis],
y=synthetic_data['min'],
marker={'color': 'rgba(1, 224, 201, 0.25)'},
showlegend=False,
mode='lines'
)
)
fig.add_trace(
go.Scatter(
name='Synthetic-Max',
x=synthetic_data[x_axis],
y=synthetic_data['max'],
marker={'color': 'rgba(1, 224, 201, 0.25)'},
showlegend=False,
mode='lines',
fill='tonexty',
fillcolor='rgba(1, 224, 201, 0.25)',
)
)
return fig


Expand Down Expand Up @@ -556,30 +613,41 @@ def get_column_line_plot(real_data, synthetic_data, column_name, metadata):
# Merge the real and synthetic data and add a flag ``Data`` to indicate each one.
r_data = real_data.copy()
s_data = synthetic_data.copy()
marker_name = 'Data'

# If there are multiple sequences in the data, split them out appropriately
if 'sequence_key' in metadata:
key_column = metadata['sequence_key']
r_data[marker_name] = 'Real-' + r_data[key_column]
s_data[marker_name] = 'Synthetic-' + s_data[key_column]
else:
r_data[marker_name] = 'Real'
s_data[marker_name] = 'Synthetic'

# Check for sequence index to determine the x-axis values
x_axis = 'sequence_index'
y_axis = column_name
if 'sequence_index' in metadata:
x_axis = metadata['sequence_index']
if 'sequence_key' in metadata:
r_data = r_data.groupby(x_axis, as_index=False).agg(
{
x_axis: 'first',
column_name: ['mean', 'min', 'max']
}
).rename(columns={'mean': column_name, 'first': x_axis})
s_data = s_data.groupby(x_axis, as_index=False).agg(
{
x_axis: 'first',
column_name: ['mean', 'min', 'max']
}
).rename(columns={'mean': column_name, 'first': x_axis})

r_data.columns = r_data.columns.droplevel(0)
s_data.columns = s_data.columns.droplevel(0)
else:
r_data['sequence_index'] = r_data.index
s_data['sequence_index'] = s_data.index

marker_name = 'Data'
r_data[marker_name] = 'Real'
s_data[marker_name] = 'Synthetic'

# Generate plot
all_data = pd.concat([r_data, s_data], axis=0, ignore_index=True)
fig = _generate_line_plot(all_data=all_data,
fig = _generate_line_plot(real_data=r_data,
synthetic_data=s_data,
x_axis=x_axis,
y_axis=column_name,
y_axis=y_axis,
marker=marker_name,
annotations=annotations)
return fig
48 changes: 33 additions & 15 deletions tests/unit/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,24 +469,23 @@ def test__generate_heatmap_plot(px_mock):
def test__generate_line_plot(px_mock):
"""Test the ``_generate_line_plot`` method."""
# Setup
real_column = pd.DataFrame({
real_data = pd.DataFrame({
'colX': [1, 2, 3, 4],
'colY': [10, 4, 20, 21],
'Data': ['Real'] * 4
})
synthetic_column = pd.DataFrame({
synthetic_data = pd.DataFrame({
'colX': [1, 2, 4, 5],
'colY': [6, 11, 9, 18],
'Data': ['Synthetic'] * 4
})

all_data = pd.concat([real_column, synthetic_column], axis=0, ignore_index=True)

mock_figure = Mock()
px_mock.line.return_value = mock_figure

# Run
fig = _generate_line_plot(all_data, x_axis='colX', y_axis='colY', marker='Data')
fig = _generate_line_plot(real_data, synthetic_data, x_axis='colX',
y_axis='colY', marker='Data')

# Assert
px_mock.line.assert_called_once_with(
Expand All @@ -507,25 +506,24 @@ def test__generate_line_plot(px_mock):
x='colX',
y='colY',
color='Data',
markers=True
markers=True,
color_discrete_map={'Real': '#000036', 'Synthetic': '#01E0C9'}
)
assert mock_figure.update_layout.called_once()
assert mock_figure.for_each_annotation.called_once()
assert fig == mock_figure

# Setup failing case
bad_column = pd.DataFrame({
bad_data = pd.DataFrame({
'colX': [1, 'bad_value', 4, 5],
'colY': [6, 7, 9, 18],
'Data': ['Synthetic'] * 4
})

bad_all_data = pd.concat([real_column, bad_column], axis=0, ignore_index=True)

# Run and Assert
match = "Sequence Index 'colX' must contain numerical or datetime values only"
with pytest.raises(ValueError, match=match):
_generate_line_plot(bad_all_data, x_axis='colX', y_axis='colY', marker='Data')
_generate_line_plot(real_data, bad_data, x_axis='colX', y_axis='colY', marker='Data')


@patch('sdmetrics.visualization.px')
Expand Down Expand Up @@ -810,12 +808,32 @@ def test_get_column_line_plot(mock__generate_line_plot):
fig = get_column_line_plot(real_data, synthetic_data, column_name='amount', metadata=metadata)

# Assert
real_data['Data'] = 'Real-' + real_data['object']
synthetic_data['Data'] = 'Synthetic-' + synthetic_data['object']
expected_call_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)

# Setup
real_data_submitted = pd.DataFrame({
'date': pd.to_datetime(
[
'2021-01-01', '2022-01-01', '2023-01-01'
]
),
'amount': [1.5, 3, 4.5],
'min': [1, 2, 3],
'max': [2, 4, 6],
'Data': ['Real', 'Real', 'Real'],
})
synthetic_data_submitted = pd.DataFrame({
'date': pd.to_datetime(
[
'2021-01-01', '2022-01-01', '2023-01-01'
]
),
'amount': [4.0, 2.0, 2.0],
'min': [4., 1., 1.],
'max': [4., 3., 3.],
'Data': ['Synthetic', 'Synthetic', 'Synthetic']
})
mock__generate_line_plot.assert_called_once_with(
all_data=DataFrameMatcher(expected_call_data),
real_data=DataFrameMatcher(real_data_submitted),
synthetic_data=DataFrameMatcher(synthetic_data_submitted),
x_axis='date',
y_axis='amount',
marker='Data',
Expand Down

0 comments on commit 39fec6e

Please sign in to comment.