如何确保 abline/identity 线在每个子图中以 1:1 居中,并在 Matplotlib 中具有单独的轴限制? [已编辑][已解决]

问题描述 投票:0回答:1

我有一个 2x2 子图网格,其中每个子图包含一个具有不同数据点的散点图。我试图在每个子图中绘制一条通用的 abline(斜率=1,截距=0)来可视化数据点之间的关系。然而,由于每个子图中的数据范围不同,abline 不会在所有子图中以 1:1 的比例显示。

我想确保 abline 在每个子图中以 1:1 居中,同时根据特定子图中的数据点维护每个图的单独轴限制。换句话说,我希望 abline 穿过每个子图数据点的中心而不扭曲数据。

有人可以指导我如何在每个子图中实现 abline 的正确居中,同时根据该子图中的数据点保持各个轴限制吗?

这就是代码:

timesteps = [185, 159, 53, 2]

def abline(ax, slope, intercept):
    """Plot a line from slope and intercept"""
    x_vals = np.array(ax.get_xlim())
    y_vals = intercept + slope * x_vals
    ax.plot(x_vals, y_vals, 'r--')

fig, axs = plt.subplots(2, 2, figsize=(12, 8))

for i, timestep in enumerate(timesteps):
    mask = np.where(nan_mask[timestep, :, :] == 0)
    data_tmwm_values = data_tmwm[timestep, :, :][mask]
    ds_plot_values = ds_og_red[timestep, :, :][mask]

    row = i // 2  # Integer division to get the row index
    col = i % 2  # Modulo operation to get the column index
    
    ax = axs[row, col]
    ax.scatter(data_tmwm_values, ds_plot_values, s=20)
    ax.set_xlabel('TMWM')
    ax.set_ylabel('Original')
    ax.set_title(f'Scatter Plot (Timestep: {timestep})')

    correlation_matrix = np.corrcoef(data_tmwm_values, ds_plot_values)
    r_value = correlation_matrix[0, 1]

    r_squared = r_value ** 2
    abline(ax, 1, 0)
    ax.text(0.05, 0.95, f"R\u00b2 value: {r_squared:.3f}", transform=ax.transAxes, ha='left', va='top')

plt.tight_layout()
plt.show()

这就是图像:

enter image description here

我已经尝试使用 get_xlim() 和 get_ylim() 函数来设置每个子图的轴限制,但它不会导致 abline 正确居中。

python matplotlib visualization scatter-plot subplot
1个回答
0
投票

您似乎想要一条恒等线,但您正在尝试线性拟合。线性拟合可能对您仍然有用,因为您计算各种相关性指标并覆盖 R2。

下面的示例展示了如何添加线性拟合以及恒等 (y=x) 线。

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

timesteps = [185, 159, 53, 2]

fig, axs = plt.subplots(2, 2, figsize=(12, 8))

for timestep, ax in zip(timesteps, axs.flatten()):
    #Synthetic data
    data_tmwm_values = np.random.randn(200) * 10 + timestep / 2
    ds_plot_values = np.random.randn(200) * 20 + timestep / 2

    ax.scatter(data_tmwm_values, ds_plot_values, s=20)
    ax.set_xlabel('TMWM')
    ax.set_ylabel('Original')
    ax.set_title(f'Scatter Plot (Timestep: {timestep})')

    correlation_matrix = np.corrcoef(data_tmwm_values, ds_plot_values)
    r_value = correlation_matrix[0, 1]
    r_squared = r_value ** 2
    ax.text(0.05, 0.95, f"R\u00b2 value: {r_squared:.3f}", transform=ax.transAxes, ha='left', va='top')
    
    #Fit a straight line
    slope, intercept = np.polyfit(data_tmwm_values, ds_plot_values, deg=1)
    #Add the line to the plot, preserving the x and y ranges of the data
    x_low, x_high, y_low, y_high = ax.axis() #Get axis limits
    ax.plot([x_low, x_high], slope * np.array([x_low, x_high]) + intercept, 'r--', label='linear fit of data')
    
    #add identity line, in case that is what you wanted
    lim_low = min(x_low, y_low)
    lim_high = max(x_high, y_high)
    ax.plot([lim_low, lim_high], [lim_low, lim_high], '-k', linewidth=2, label='y=x identity line')
    
    #add legend for a plot, to clarify what the lines represent
    if ax is axs[0, 1]: ax.legend(loc='upper right') 
    
    #optional - clip limits to remove some padding
    ax.axis([lim_low, lim_high, lim_low, lim_high])
    
plt.tight_layout()
plt.show()
© www.soinside.com 2019 - 2024. All rights reserved.