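I'm new to coding, so please bear with me. I'm trying to model several different probability distributions: for each one I want to fit a Gaussian, find the standard deviation of the fitted Gaussian, and then compare those standard deviations.
I'm not sure how to tell Python to find the Gaussian that best fits my curve, so any help would be appreciated.
Here is my code so far. I tried using curve_fit, but I didn't really understand how to use it, so I took it out. I'm also not sure how to define the Gaussian so that I get the best fit.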
```python
import numpy as np
import matplotlib.pyplot as plt
from scipy import linalg, special, optimize, stats
from scipy.optimize import curve_fit   # (`from numpy import math` removed: numpy.math no longer exists in NumPy 2)

q = 4000
qA = np.arange(0.000001, q + 1, 1.0)
qB = q - qA   # qB[-1] is slightly negative, so the last entries of ASB and Paprox are NaN
N = 2000
NA = 1000
NB = N - NA

# Stirling approximation of S/k_B = ln[(q + N - 1)! / (q! (N - 1)!)]
ASA = (qA + NA - 1)*np.log(qA + NA - 1) - qA*np.log(qA) - (NA - 1)*np.log(NA - 1)  # S_A/k_B
ASB = (qB + NB - 1)*np.log(qB + NB - 1) - qB*np.log(qB) - (NB - 1)*np.log(NB - 1)  # S_B/k_B

# Central-difference derivative dq/d(S/k_B) for subsystem A
TATA_list = []
for i in range(1, len(qA) - 1):
    TATE = (qA[i+1] - qA[i-1]) / (ASA[i+1] - ASA[i-1])
    TATA_list.append(TATE)

# Same derivative for subsystem B
TBTA_list = []
for i in range(1, len(qA) - 1):
    TBTE = (qB[i+1] - qB[i-1]) / (ASB[i+1] - ASB[i-1])
    TBTA_list.append(TBTE)

AStot = (q + N - 1)*np.log(q + N - 1) - q*np.log(q) - (N - 1)*np.log(N - 1)
Paprox = np.exp(ASA + ASB - AStot)
SUMPX = Paprox[:-1].sum()        # skip the NaN last entry
NormPAprox = Paprox / SUMPX

plt.plot(qA/q, NormPAprox)
plt.show()
```
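Since the question mentions `curve_fit`, here is a minimal sketch of how it could be used to fit a Gaussian to a curve given as two arrays `x` and `y` (for example `qA/q` and `NormPAprox`). The helper names `gaussian` and `fit_gaussian` are hypothetical, and a synthetic Gaussian curve with a known σ = 0.02 stands in for the real data so the result can be checked.

```python
import numpy as np
from scipy.optimize import curve_fit

def gaussian(x, amp, mu, sigma):
    """Unnormalised Gaussian: amp * exp(-(x - mu)^2 / (2 sigma^2))."""
    return amp * np.exp(-(x - mu)**2 / (2 * sigma**2))

def fit_gaussian(x, y):
    """Least-squares fit of `gaussian` to the curve (x, y); returns (amp, mu, sigma)."""
    # Initial guesses taken from the curve itself: peak height,
    # y-weighted mean and y-weighted standard deviation of x
    mu0 = np.sum(x * y) / np.sum(y)
    sigma0 = np.sqrt(np.sum(y * (x - mu0)**2) / np.sum(y))
    p0 = [y.max(), mu0, sigma0]
    popt, _ = curve_fit(gaussian, x, y, p0=p0)
    amp, mu, sigma = popt
    return amp, mu, abs(sigma)  # sigma only enters squared, so its sign is arbitrary

# Hypothetical stand-in for (qA/q, NormPAprox): a sampled, normalised Gaussian with sigma = 0.02
x = np.linspace(0, 1, 4001)
y = gaussian(x, 1.0, 0.5, 0.02)
y /= y.sum()

amp, mu, sigma = fit_gaussian(x, y)
print(f"mu = {mu:.4f}, sigma = {sigma:.4f}")
```

To apply it to the question's data, drop the NaN last point first, e.g. `amp, mu, sigma = fit_gaussian(qA[:-1] / q, NormPAprox[:-1])`, because `curve_fit` raises an error on NaN input by default. The returned `sigma` is in the units of `x`, so distributions plotted against the same x-axis can be compared directly. The weighted-moment initial guesses also give the curve's standard deviation on their own, without any fitting.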
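Here is an example that uses `sklearn` to fit a Gaussian to some data, with some visualization. The data is just a list of measurements. You can fit more than one Gaussian (via `n_components`), and they can be multidimensional.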
Create some simulated data and plot it:
```python
import numpy as np
import matplotlib.pyplot as plt

# Create some test data
np.random.seed(0)
measurements = np.random.randn(100)

# Plot the data
plt.scatter(range(0, len(measurements)), measurements, marker='s', c='tab:green', label='raw data')
plt.xlabel('sample number')
plt.ylabel('measurement')
plt.legend()
```
Fit the model and report the results:
```python
# Fit Gaussian(s) to the data
from sklearn.mixture import GaussianMixture

n_gaussians = 1
gmm = GaussianMixture(n_components=n_gaussians).fit(measurements.reshape(-1, 1))

# Print results for each component
print(f'Fitted {n_gaussians} Gaussians to {len(measurements)} samples:')
for idx in range(n_gaussians):
    print(f'  Gaussian {idx} mean:', gmm.means_[idx, 0].round(2))
    print(f'  Gaussian {idx} var: ', gmm.covariances_[idx, 0, 0].round(2))

# Overlay the fitted Gaussians
for idx in range(n_gaussians):
    mean = gmm.means_[idx, 0]            # scalar mean of component idx
    var = gmm.covariances_[idx, 0, 0]    # scalar variance (default covariance_type='full')
    colour = plt.get_cmap('jet')((idx + 1) / n_gaussians)
    plt.axhline(mean, c=colour, lw=2, ls='-', label=f'Gaussian {idx} mean')
    plt.vlines(idx, ymin=mean - var**0.5, ymax=mean + var**0.5, colors=colour, lw=2, label=f'Gaussian {idx} std')
plt.legend(bbox_to_anchor=(1.1, 1))
plt.show()
```
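With a single component, the `GaussianMixture` fit reduces to the maximum-likelihood estimate of a normal distribution, so it can be cross-checked against `scipy.stats.norm.fit`, which computes that estimate in closed form. A minimal sketch, regenerating the same simulated measurements:

```python
import numpy as np
from scipy import stats
from sklearn.mixture import GaussianMixture

np.random.seed(0)
measurements = np.random.randn(100)

# One-component GMM: EM lands on the maximum-likelihood mean and variance
gmm = GaussianMixture(n_components=1).fit(measurements.reshape(-1, 1))
gmm_mean = gmm.means_[0, 0]
gmm_std = np.sqrt(gmm.covariances_[0, 0, 0])   # includes reg_covar (default 1e-6)

# Closed-form maximum-likelihood fit of a normal distribution
loc, scale = stats.norm.fit(measurements)       # scale is the ddof=0 standard deviation

print(f"GMM:      mean={gmm_mean:.4f}, std={gmm_std:.4f}")
print(f"norm.fit: mean={loc:.4f}, std={scale:.4f}")
```

The two agree up to the small `reg_covar` term that scikit-learn adds to the variance. Both methods expect raw samples; for a curve given as (x, y) pairs, such as the question's `NormPAprox`, fitting the curve or taking its weighted moments is the more direct route.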