如何在热图中重新定位星号（显著性标记）

问题描述 投票：0 回答：1

如何更改热图中星号的位置？我希望星号位于数字所在方格的右上角附近，而不是位于数字上方。把它放在角的内侧、靠近角的位置（而不是在该角的上方或外侧），可以提高可读性。这能做到吗？

from string import ascii_letters
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import pearsonr

sns.set_theme(style="white")
rs = np.random.RandomState(33)
# ascii_letters[26:] is "ABC...Z", giving 26 upper-case column names.
df = pd.DataFrame(data=rs.normal(size=(100, 26)),
                  columns=list(ascii_letters[26:]))

# Per-cell text for the heatmap: the correlation to two decimals, with "*"
# appended when the Pearson p-value is below 0.05.
annotations = pd.DataFrame(index=df.columns, columns=df.columns)

# Correlation matrix that drives the cell colours.
correlations = df.corr()

# Pearson r (and its p-value) is symmetric in its two arguments, so each
# unordered pair is tested once and the annotation is mirrored across the
# diagonal; a column's correlation with itself is always 1.00.
cols = list(df.columns)
for i, col1 in enumerate(cols):
    annotations.loc[col1, col1] = f"{1:.2f}"
    for col2 in cols[i + 1:]:
        corr, p = pearsonr(df[col1], df[col2])
        annotation = f"{corr:.2f}"
        if p < 0.05:
            annotation += "*"
        annotations.loc[col1, col2] = annotation
        annotations.loc[col2, col1] = annotation

# Custom diverging colormap
cmap = sns.diverging_palette(250, 10, as_cmap=True)

# Hide the upper triangle (including the diagonal), which duplicates the lower one.
mask = np.triu(np.ones_like(correlations, dtype=bool))

plt.figure(figsize=(20, 15))
# fmt='' because the annotations are pre-formatted strings, not numbers.
sns.heatmap(correlations, mask=mask, cmap=cmap, vmax=1, center=0, vmin=-1,
            square=True, linewidths=.5, cbar_kws={"shrink": .5}, annot=annotations, fmt='')

plt.title('Correlation Matrix')
plt.show()
python matplotlib seaborn visualization
1个回答
0
投票

您可以用 `ax.text` 直接在指定位置写入文本（星号）：

from itertools import combinations, product
from string import ascii_uppercase

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import pearsonr

rs = np.random.RandomState(33)

N = 10
columns = list(ascii_uppercase[:N])

# Significance level: cells whose p-value is below this get an asterisk.
ALPHA = 0.05
# Asterisk anchor inside a cell, as fractions of the cell measured from its
# top-left corner. The heatmap's y axis points downward, so a small y offset
# lies near the top edge; together with ha="right"/va="top" below, this puts
# the star just inside the top-right corner.
STAR_DX, STAR_DY = 0.95, 0.05

df = pd.DataFrame(data=rs.normal(size=(100, N)), columns=columns)

# Correlation and p-value matrices, initialised to NaN; cells that are never
# filled below stay NaN.
correlations = pd.DataFrame(np.full((N, N), np.nan), index=columns, columns=columns)
p_values = pd.DataFrame(np.full((N, N), np.nan), index=columns, columns=columns)

# Pearson r is symmetric and the diagonal is trivially 1, so only the strictly
# lower triangle (row col2, column col1) is computed.
for col1, col2 in combinations(columns, 2):
    corr, p = pearsonr(df[col1], df[col2])
    correlations.loc[col2, col1] = corr
    p_values.loc[col2, col1] = p

# Figure
sns.set_theme(style="white")
cmap = sns.diverging_palette(250, 10, as_cmap=True)
fig, ax = plt.subplots(figsize=(10, 8))

sns.heatmap(
    correlations,
    cmap=cmap,
    vmax=1,
    center=0,
    vmin=-1,
    square=True,
    linewidths=0.5,
    cbar_kws={"shrink": 0.5},
    annot=True,
    fmt=".2f",
    annot_kws=dict(fontsize=9),
    ax=ax,
)

# Draw the asterisk as separate text in the cell corner instead of appending
# it to the number. Cell (i, j) spans x in [j, j+1] and y in [i, i+1].
n_rows, n_cols = correlations.shape
for i, j in product(range(n_rows), range(n_cols)):
    # NaN p-values (unfilled cells) compare False, so they never get a star.
    if p_values.iloc[i, j] < ALPHA:
        ax.text(
            j + STAR_DX,
            i + STAR_DY,
            "*",
            ha="right",
            va="top",
            color="red",
            fontsize=20,
            fontweight="bold",
        )

ax.set_title("Correlation Matrix")
plt.show()

© www.soinside.com 2019 - 2024. All rights reserved.