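I'm trying to create 2 plots with a shared x-axis, and I'm running into 2 issues:
1. As soon as I customize the layout with yaxis and yaxis2 titles and/or ticks, the y-axes start overlapping.
2. The legend entries are duplicated across the two plots.
Here is the code that reproduces the issues: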
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=True) # using jupyter
import plotly.graph_objs as go
from plotly import tools
import numpy as np
N = 100
epoch_range = [i for i in range(N)]
model_perf = {}
for m in ['acc','loss']:
    for sub in ['train','validation']:
        if sub == 'train':
            history_target = m
        else:
            history_target = 'val_{}'.format(m)
        model_perf[history_target] = np.random.random(N)
line_type = {
    'train': dict(
        color='grey',
        width=1,
        dash='dash'
    ),
    'validation': dict(
        color='blue',
        width=4
    )
}
fig = tools.make_subplots(rows=2, cols=1, shared_xaxes=True, shared_yaxes=False, specs = [[{'b':10000}], [{'b':10000}]])
i = 0
for m in ['acc','loss']:
    i += 1
    for sub in ['train','validation']:
        if sub == 'train':
            history_target = m
        else:
            history_target = 'val_{}'.format(m)
        fig.append_trace({
            'x': epoch_range,
            'y': model_perf[history_target],
            #'type': 'scatter',
            'name': sub,
            'legendgroup': m,
            'yaxis': dict(title=m),
            'line': line_type[sub],
            'showlegend': True
        }, i, 1)
fig['layout'].update(
    height=600,
    width=800,
    xaxis = dict(title = 'Epoch'),
    yaxis = dict(title='Accuracy', tickformat=".0%"),
    yaxis2 = dict(title='Loss', tickformat=".0%"),
    title='Performance'
)
iplot(fig)
If you have any suggestions on how to solve these 2 issues, I'd love to hear them.
Many thanks in advance!
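EDIT:
Following the suggestion from Farbice, I looked into the create_facet_grid function from plotly.figure_factory (which, by the way, requires plotly 2.0.12+). I did manage to reproduce the same figure with fewer lines of code, but it gave me less flexibility. For example, I don't think you can draw lines with this function, and it also has the legend duplication issue. Still, if you are looking for an ad-hoc viz, it can be pretty effective. It requires data in the long format, see the example below: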
# converting into the long format
import pandas as pd
perf_df = (
    pd.DataFrame({
        'accuracy_train': model_perf['acc'],
        'accuracy_validation': model_perf['val_acc'],
        'loss_train': model_perf['loss'],
        'loss_validation': model_perf['val_loss']
    })
    .stack()
    .reset_index()
    .rename(columns={
        'level_0': 'epoch',
        'level_1': 'variable',
        0: 'value'
    })
)
perf_df = pd.concat(
    [
        perf_df,
        perf_df['variable']
        .str
        .extractall(r'(?P<metric>^.*)_(?P<set>.*$)')
        .reset_index()[['metric','set']]
    ], axis=1
).drop(['variable'], axis=1)
perf_df.head() # result
epoch     value    metric         set
    0  0.434349  accuracy       train
    0  0.374607  accuracy  validation
    0  0.864698      loss       train
    0  0.007445      loss  validation
    1  0.553727  accuracy       train
# plot it
import plotly.figure_factory as ff  # provides create_facet_grid
fig = ff.create_facet_grid(
    perf_df,
    x='epoch',
    y='value',
    facet_row='metric',
    color_name='set',
    scales='free_y',
    ggplot2=True
)
fig['layout'].update(
    height=800,
    width=1000,
    yaxis1 = dict(tickformat=".0%"),
    yaxis2 = dict(tickformat=".0%"),
    title='Performance'
)
iplot(fig)
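In my experience, visualization tools prefer data in the long format. You may want to reshape your data into a table with the following columns: epoch, variable, set, value.
By doing so, you will probably find it easier to create the chart you want by faceting on 'variable', with one trace per 'set' that has x = epoch and y = value.
If you want a coded solution, please provide some data.
Hope this helps.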
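The long-format reshaping suggested above can be sketched without any plotting library. This is only an illustration under assumptions: the numbers are made-up placeholders, `to_long_format` is a hypothetical helper name, and the `val_` prefix convention is taken from the question's `model_perf` keys.

```python
# Wide format: one list per metric/set combination, keyed like the
# question's model_perf dict. Values are made-up placeholders.
model_perf = {
    'acc': [0.50, 0.60],
    'val_acc': [0.40, 0.55],
    'loss': [0.90, 0.70],
    'val_loss': [1.00, 0.80],
}

def to_long_format(perf):
    """Return one row per (epoch, variable, set) with its value."""
    rows = []
    for key, values in perf.items():
        # Keys prefixed with 'val_' are treated as the validation set.
        if key.startswith('val_'):
            variable, subset = key[len('val_'):], 'validation'
        else:
            variable, subset = key, 'train'
        for epoch, value in enumerate(values):
            rows.append({
                'epoch': epoch,
                'variable': variable,
                'set': subset,
                'value': value,
            })
    return rows

long_rows = to_long_format(model_perf)
for row in long_rows[:3]:
    print(row)
```

Each row now carries its own 'variable' and 'set' labels, so a faceting function can split on 'variable' and color by 'set' without further bookkeeping.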
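After doing a bit of digging, I found solutions to both issues.
First, the overlapping y-axes issue was caused by the yaxis parameter in the layout update; it had to be changed to yaxis1.
The second issue, the duplication in the legend, was a bit trickier, but this post helped me work it out. The idea is that each trace has a legend entry associated with it, so if you are plotting multiple traces, you may want to show the legend entry for only one of them (using the showlegend parameter). To make sure that one legend entry toggles the traces on multiple subplots, you can use the legendgroup parameter.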
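As a minimal, library-free sketch of that idea, the snippet below builds plain trace dicts and checks which ones would show a legend entry. The helper `build_traces` and the `row` bookkeeping key are hypothetical and only for illustration; `name`, `legendgroup`, and `showlegend` are the trace attributes discussed above.

```python
def build_traces(metrics, subsets):
    """Build plain trace dicts (hypothetical helper, not a plotly function)."""
    traces = []
    for row, metric in enumerate(metrics, start=1):
        for subset in subsets:
            traces.append({
                'name': subset,
                # Same group on every subplot, so one legend click toggles both.
                'legendgroup': subset,
                # Only the traces on the first subplot show a legend entry.
                'showlegend': row == 1,
                # Bookkeeping only; not a plotly trace attribute.
                'row': row,
            })
    return traces

traces = build_traces(['acc', 'loss'], ['train', 'validation'])
legend_entries = [t['name'] for t in traces if t['showlegend']]
print(legend_entries)  # ['train', 'validation']
```

Four traces are created, but only two legend entries are shown, and each entry shares its legendgroup with the matching trace on the other subplot.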
Here is the full code for the solution:
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=True) # using jupyter
import plotly.graph_objs as go
from plotly import tools
import numpy as np
N = 100
epoch_range = [i for i in range(N)]
model_perf = {}
for m in ['acc','loss']:
    for sub in ['train','validation']:
        if sub == 'train':
            history_target = m
        else:
            history_target = 'val_{}'.format(m)
        model_perf[history_target] = np.random.random(N)
line_type = {
    'train': dict(
        color='grey',
        width=1,
        dash='dash'
    ),
    'validation': dict(
        color='blue',
        width=4
    )
}
fig = tools.make_subplots(
    rows=2,
    cols=1,
    shared_xaxes=True,
    shared_yaxes=False
)
i = 0
for m in ['acc','loss']:
    i += 1
    if m == 'acc':
        legend_display = True
    else:
        legend_display = False
    for sub in ['train','validation']:
        if sub == 'train':
            history_target = m
        else:
            history_target = 'val_{}'.format(m)
        fig.append_trace({
            'x': epoch_range,
            'y': model_perf[history_target],
            'name': sub,
            'legendgroup': sub, # toggle train / validation group on all subplots
            'yaxis': dict(title=m),
            'line': line_type[sub],
            'showlegend': legend_display # this is now dependent on the trace
        }, i, 1)
fig['layout'].update(
    height=600,
    width=800,
    xaxis = dict(title = 'Epoch'),
    yaxis1 = dict(title='Accuracy', tickformat=".0%"),
    yaxis2 = dict(title='Loss', tickformat=".0%"),
    title='Performance'
)
iplot(fig)
Here is the result: