如何在Python中修复plotly 3D散点图的图例?

问题描述 投票:0回答:1

我正在尝试使用Python中的plotly创建一些3D散点图。我有一个包含一些性格特征和聚类标签的数据框,我想使用颜色和符号来绘制每个特征的 TSNE 嵌入来表示特征和聚类值。这是我的代码:

from plotly.subplots import make_subplots

# Replace y and n with numbers

df2 = df.replace({'y': 1, 'n': 0})

# Assign cluster labels

df2 = df2.assign(Cluster=labels)

# Create a list of plot titles for each personality trait
plot_titles = ['Extraversion', 'Neuroticism', 'Agreeableness', 'Conscientiousness', 'Openness']

# Plot the TSNE embeddings for each personality trait
for i, trait in enumerate(['cEXT', 'cNEU', 'cAGR', 'cCON', 'cOPN']):
  # Create a single plot figure for each trait
  fig = px.scatter_3d(df2,
                      x=embeddings_3d[:, 0], y=embeddings_3d[:, 1], z=embeddings_3d[:, 2],
                      color = df2[trait], symbol=df2['Cluster'],
                      color_discrete_map={0: '#FF0000', 1: '#0000FF'},
                      size_max=1, symbol_map={0: 'circle', 1: 'square', 2: 'diamond', 3: 'cross', 4: 'x'},
                      opacity=0.3)
  # Update the layout of the plot figure with the title
  fig.update_layout(title=plot_titles[i])
  
  # Show the plot figure
  fig.show()

然而,剧情的传说却是一团糟。它显示了颜色和簇表示的渐变图例,但我不需要渐变。我的类别是标签,因此是分类的。只是名字而已!而且,这一切都是混合的。所以我需要呈现的只是每个数据点代表的集群的标签。

如何修复图的图例以仅显示具有相应颜色和符号的簇标签?任何帮助,将不胜感激。谢谢!

这是数据:df2

python 3d plotly legend scatter-plot
1个回答
0
投票

这是一个可能的解决方案的草案,因为我仍然缺少问题中未提供的一些参数(标签和嵌入_3d)。检查这是否是您正在寻找的内容。但请注意,我无法运行该代码,因此这是我刚刚在脑海中写下的内容。

from plotly.subplots import make_subplots
import plotly.express as px

# Replace y and n with numbers
df2 = df.replace({'y': 1, 'n': 0})

# Assign cluster labels
df2 = df2.assign(Cluster=labels)

# Create a list of plot titles for each personality trait
plot_titles = ['Extraversion', 'Neuroticism', 'Agreeableness', 'Conscientiousness', 'Openness']

# Initialize an empty list to store traces
traces = []

# Plot the TSNE embeddings for each personality trait
for i, trait in enumerate(['cEXT', 'cNEU', 'cAGR', 'cCON', 'cOPN']):
    # Create a scatter plot trace for the current trait
    trace = px.scatter_3d(df2,
                           x=embeddings_3d[:, 0], y=embeddings_3d[:, 1], z=embeddings_3d[:, 2],
                           color=df2[trait], symbol=df2['Cluster'],
                           color_discrete_map={0: '#FF0000', 1: '#0000FF'},
                           size_max=1, symbol_map={0: 'circle', 1: 'square', 2: 'diamond', 3: 'cross', 4: 'x'},
                           opacity=0.3,
                           title=plot_titles[i],
                           showlegend=True)  # Set showlegend to True
    # Append the trace to the list
    traces.append(trace)

# Create a subplot with all the traces
fig = make_subplots(rows=1, cols=len(traces))

for i, trace in enumerate(traces):
    # Add each trace to the subplot
    fig.add_trace(trace, row=1, col=i + 1)

# Update the subplot layout
fig.update_layout(title_text="Personality Traits and Cluster Labels", showlegend=True)

# Show the plot
fig.show()

© www.soinside.com 2019 - 2024. All rights reserved.