Skip to content

Commit

Permalink
tsne with legend
Browse files Browse the repository at this point in the history
vinicvaz committed Nov 27, 2023
1 parent c51787c commit e0eefd5
Showing 1 changed file with 28 additions and 2 deletions.
30 changes: 28 additions & 2 deletions pieces/TSNEPiece/piece.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
import pandas as pd
from pathlib import Path
import plotly.express as px
import plotly.graph_objects as go


class TSNEPiece(BasePiece):
@@ -28,10 +29,35 @@ def piece_function(self, input_data: InputModel):
tsne_df['target'] = df['target']

if input_data.n_components >= 2:
fig = go.Figure()
color_scale = px.colors.qualitative.Bold
if input_data.use_class_column:
fig = px.scatter(tsne_df, x='tsne_0', y='tsne_1', color='target')
unique_targets = tsne_df['target'].unique()
for idx, target_value in enumerate(unique_targets):
color = color_scale[idx % len(color_scale)]
filtered_data = tsne_df[tsne_df['target'] == target_value]
fig.add_trace(
go.Scatter(
x=filtered_data['tsne_0'],
y=filtered_data['tsne_1'],
mode='markers',
name=f'Target: {target_value}',
marker=dict(
color=color,
),
)
)
else:
fig = px.scatter(tsne_df, x='tsne_0', y='tsne_1')
color = color_scale[0]
fig.add_trace(
go.Scatter(
x=tsne_df['tsne_0'],
y=tsne_df['tsne_1'],
mode='markers',
)
)

# Create a combined figure from all separate traces
fig.update_layout(
title="t-SNE Projection - First two dimensions",
xaxis_title="First Dimension",

0 comments on commit e0eefd5

Please sign in to comment.