Python / Numpy优化

问题描述 投票:1回答:4

我有一个简单的python类,大约有40行计算,此后给出一个用例示例,它执行简单的计算(基于密度之间的L2距离进行独立测试),并且计算需要花费大量时间只有100点和100升压。这是代码和一些测试数据:

import numpy as np

class IndependenceTesting:


    def __init__(self,data,a,b,dim_to_test,number_of_simulation = 1000):

        # rescale the data :
        self.i = dim_to_test
        self.data = (data - a)/(b-a)
        self.d = self.data.shape[1]
        self.n = self.data.shape[0]
        self.N = number_of_simulation

        self.data_restricted = np.hstack((self.data[:,:(self.i-1)],self.data[:,(self.i+1):]))
        self.emp_cop_restricted = np.array([np.mean(np.array([np.sum(dat <= u) for dat in self.data_restricted]) == self.d - 1) for u in self.data_restricted])

    def simulated_dataset(self):
        unif = np.random.uniform(size=(self.n,1))
        before = [x for x in range(self.d) if x < self.i]
        after = [x for x in range(self.d) if x > self.i]
        return np.hstack((self.data[:,before],unif,self.data[:,after]))

    def mse(self, data):
        emp_cop = np.array([np.mean(np.array([np.sum(dat <= u) for dat in data]) == self.d) for u in self.data])
        return ((emp_cop - self.emp_cop_restricted)**2).mean()

    def mse_distribution(self):
        return np.array([self.mse(self.simulated_dataset()) for i in np.arange(self.N)])

    def mse_observed(self):
        return self.mse(self.data)

    def quantile(self):
        return np.mean(self.mse_distribution() < self.mse_observed())

    def p_value(self):
        return 1 - self.quantile()

# clayton copula exemple
points = np.array([[0.56129339, 0.99710045, 0.57646982],
       [0.12256328, 0.17201513, 0.12885428],
       [0.08511945, 0.11828913, 0.10965346],
       [0.98324131, 0.95728269, 0.92776529],
       [0.54921155, 0.39785825, 0.31361901],
       [0.99487892, 0.63916092, 0.79895483],
       [0.50433754, 0.56999504, 0.60091257],
       [0.92823054, 0.93214344, 0.89725172],
       [0.17751366, 0.18346635, 0.20097246],
       [0.51466364, 0.63436169, 0.46611089],
       [0.25800664, 0.28831929, 0.2903953 ],
       [0.20481173, 0.15871781, 0.15857803],
       [0.31187595, 0.24635342, 0.24171054],
       [0.93662273, 0.80126302, 0.90160681],
       [0.34507788, 0.28888433, 0.30064778],
       [0.81832302, 0.84296836, 0.73139211],
       [0.90751759, 0.74184158, 0.60553314],
       [0.38432821, 0.28571601, 0.22660958],
       [0.47439066, 0.71614234, 0.54718021],
       [0.19106315, 0.31102177, 0.18200903],
       [0.38445433, 0.53108707, 0.35387428],
       [0.77625631, 0.98215295, 0.7751224 ],
       [0.52178207, 0.60999481, 0.45028018],
       [0.2446548 , 0.22270593, 0.30778265],
       [0.62656838, 0.68516045, 0.49434858],
       [0.04573006, 0.03194788, 0.04361497],
       [0.09852491, 0.09004012, 0.08412001],
       [0.11361961, 0.10879038, 0.11352351],
       [0.86116076, 0.92607349, 0.98481143],
       [0.47235565, 0.89094039, 0.52014104],
       [0.32994434, 0.38757998, 0.48919507],
       [0.0052988 , 0.00701797, 0.00637456],
       [0.6230293 , 0.48457337, 0.73184841],
       [0.8039672 , 0.78400854, 0.76272398],
       [0.4585257 , 0.64504907, 0.42333538],
       [0.86565877, 0.89902376, 0.75903263],
       [0.96763817, 0.883972  , 0.99965508],
       [0.72431971, 0.86391135, 0.73501178],
       [0.99153281, 0.98536847, 0.93416086],
       [0.11746542, 0.1142617 , 0.09463402],
       [0.86322008, 0.79150614, 0.48112103],
       [0.031247  , 0.03196072, 0.02701867],
       [0.44120581, 0.48729271, 0.4607829 ],
       [0.01393345, 0.01400763, 0.01567294],
       [0.24365903, 0.20966226, 0.218757  ],
       [0.94584172, 0.94507558, 0.98623726],
       [0.79201305, 0.65503713, 0.79137242],
       [0.06040952, 0.04573984, 0.04640926],
       [0.5673345 , 0.27567432, 0.35234249],
       [0.15860006, 0.12212839, 0.15206467],
       [0.00826576, 0.00407989, 0.00479213],
       [0.72549979, 0.70557491, 0.60543315],
       [0.83039818, 0.76500639, 0.89549151],
       [0.6844257 , 0.81317716, 0.74480599],
       [0.36904583, 0.41081094, 0.36072341],
       [0.14211919, 0.14508685, 0.11253501],
       [0.85139993, 0.86351303, 0.9571894 ],
       [0.72638876, 0.92343587, 0.67884759],
       [0.26816568, 0.22169953, 0.28666315],
       [0.04672121, 0.06183976, 0.09154045],
       [0.81235354, 0.61478793, 0.76379907],
       [0.3562006 , 0.2863009 , 0.31200338],
       [0.42761726, 0.40890689, 0.53401233],
       [0.66337324, 0.96621491, 0.86041736],
       [0.55199335, 0.49320256, 0.43633604],
       [0.80474216, 0.72338883, 0.80206245],
       [0.10724037, 0.11511572, 0.09207419],
       [0.36170945, 0.21664901, 0.20827803],
       [0.9831956 , 0.93518925, 0.89061586],
       [0.10740562, 0.10503344, 0.12320474],
       [0.67589713, 0.65032996, 0.69570242],
       [0.07020206, 0.04963921, 0.06650148],
       [0.4841555 , 0.68809898, 0.65333047],
       [0.60416479, 0.74849448, 0.90509825],
       [0.59250114, 0.71818894, 0.52021291],
       [0.64724464, 0.91296217, 0.96050912],
       [0.75206371, 0.83658298, 0.74361849],
       [0.7338096 , 0.58894243, 0.68243507],
       [0.63778258, 0.79158918, 0.69136578],
       [0.73200902, 0.91405125, 0.81908408],
       [0.15349378, 0.19096759, 0.18099441],
       [0.53616182, 0.51364115, 0.49836299],
       [0.60663723, 0.66756579, 0.66600087],
       [0.72565001, 0.84115262, 0.76362573],
       [0.65200849, 0.86601501, 0.80996763],
       [0.02593363, 0.03604641, 0.05726403],
       [0.39141485, 0.31616432, 0.36365569],
       [0.64372213, 0.53823589, 0.88647631],
       [0.79079997, 0.74427728, 0.67554193],
       [0.07105107, 0.08504079, 0.09113675],
       [0.82765688, 0.7680246 , 0.93645974],
       [0.42258547, 0.46685121, 0.46316008],
       [0.08749291, 0.09122353, 0.10884091],
       [0.93644383, 0.81629942, 0.70997887],
       [0.92635455, 0.95107457, 0.99150588],
       [0.05725108, 0.03565845, 0.03288627],
       [0.11064689, 0.11070949, 0.11499569],
       [0.93098314, 0.98552576, 0.93522353],
       [0.91617665, 0.8137873 , 0.71928403],
       [0.93477362, 0.87389527, 0.87646188]])
points[:,2] = 1 - points[:,2]
points = np.concatenate((points,np.array(np.random.uniform(size=points.shape[0])).reshape((points.shape[0],1))),axis=1)


p_values = [IndependenceTesting(points,np.repeat(0,4),np.repeat(1,4),dim_to_test = i,number_of_simulation= 100).p_value() for i in np.arange(points.shape[1])]

print(p_values)

花费最多时间的那一行可能是emp_cop函数中mse的计算。

您认为此代码可以优化吗?我是python的新手。

谢谢!

python numpy micro-optimization
4个回答
2
投票

通常,python中的列表理解速度很慢,而broadcasting则效率更高。在您的特定情况下,np.array([np.sum(dat <= u) for dat in data])需要很多时间,可以用np.sum(data <= u, axis=1)代替。在我的实验中,这大大提高了速度。您可以通过用广播语句替换其他列表理解来进一步提高性能。


2
投票

我在另一个答案中提到了使用numba以及广播的解决方案。该代码必须稍作重组。 numba是否更快取决于确切的实验次数和data的形状。我还做了一个变体,它使用了numba不支持的稍微不同的广播方法(由于将双循环而不是单循环向量化,因此是先进的。) 'method'关键字可以使用三种方法。


0
投票

我认为hstack会导致代码性能下降,如果您的观点很少,请改为测试np.delete(data,0,i)。


0
投票

这是@ user2653663 Numba方法的注释。如果您想真正有效地使用Numba,建议编写简单的循环,而在Numpy中,您将尝试避免任何显式循环。

© www.soinside.com 2019 - 2024. All rights reserved.