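To estimate the relative performance of different approaches, we compute the standard deviation of one million floating-point values using:
- pure Python
- NumPy
- Cython
- C (called from Cython)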
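For reference, every implementation below computes the population standard deviation. It divides by n rather than n - 1, which matches the default `ddof=0` of `np.std`. The Python, Cython and C versions all follow the same two-step recipe: first the mean, then the mean of the squared deviations from it. For values x_1, ..., x_n:

$$
\mu = \frac{1}{n}\sum_{i=1}^{n} x_i, \qquad
\sigma = \sqrt{\frac{1}{n}\sum_{i=1}^{n} (x_i - \mu)^2}
$$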
Pure Python version:

    # File: StdDev.py
    import math

    def pyStdDev(a):
        mean = sum(a) / len(a)
        return math.sqrt(sum((x - mean) ** 2 for x in a) / len(a))

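A quick way to confirm that the pure-Python function really computes the population standard deviation is to compare it with the standard library's `statistics.pstdev`. The sketch below copies `pyStdDev` so it runs on its own; the random test data and tolerance are arbitrary choices for illustration.

```python
import math
import random
import statistics

def pyStdDev(a):
    # Copy of the two-pass pure-Python function from StdDev.py
    mean = sum(a) / len(a)
    return math.sqrt(sum((x - mean) ** 2 for x in a) / len(a))

# Small hand-checkable case: mean 5, squared deviations sum to 32, 32 / 8 = 4
print(pyStdDev([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]))  # → 2.0

# Compare against the standard library's population standard deviation
random.seed(0)
data = [random.random() for _ in range(10_000)]
print(math.isclose(pyStdDev(data), statistics.pstdev(data), rel_tol=1e-9))  # → True
```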
Using NumPy (in the same StdDev.py module):

    # File: StdDev.py
    import numpy as np

    def npStdDev(a):
        return np.std(a)

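`np.std` defaults to `ddof=0`, so it divides by n exactly like `pyStdDev`; passing `ddof=1` divides by n - 1 and gives the sample standard deviation instead. The standard library draws the same distinction with `statistics.pstdev` and `statistics.stdev`, which this stdlib-only sketch uses to show that the two differ by a factor of sqrt(n / (n - 1)). For a million elements that factor is about 1.0000005, so the choice does not matter here, but it does for small inputs.

```python
import math
import statistics

data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
n = len(data)

population = statistics.pstdev(data)  # divides by n, like np.std(a) (ddof=0)
sample = statistics.stdev(data)       # divides by n - 1, like np.std(a, ddof=1)

print(population)  # → 2.0
# The sample value is the population value scaled by sqrt(n / (n - 1))
print(math.isclose(sample, population * math.sqrt(n / (n - 1))))  # → True
```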
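A simple Cython version. It is ordinary Python code compiled by Cython, and `a` must be a NumPy array because the function calls `a.mean()`:

    # File: cyStdDev.pyx
    import math

    def cyStdDev(a):
        # All arithmetic here is done by NumPy array operations
        m = a.mean()
        w = a - m          # temporary array of deviations
        wSq = w ** 2       # temporary array of squared deviations
        return math.sqrt(wSq.mean())
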
Cython optimised with statically typed NumPy buffers (same cyStdDev.pyx file):

    # File: cyStdDev.pyx
    from libc.math cimport sqrt
    from numpy cimport ndarray
    cimport numpy as np
    cimport cython

    @cython.boundscheck(False)
    def cyOptStdDev(ndarray[np.float64_t, ndim=1] a not None):
        cdef Py_ssize_t i
        cdef Py_ssize_t n = a.shape[0]
        cdef double m = 0.0
        cdef double v = 0.0
        # First pass: accumulate the mean
        for i in range(n):
            m += a[i]
        m /= n
        # Second pass: accumulate the squared deviations from the mean
        for i in range(n):
            v += (a[i] - m)**2
        return sqrt(v / n)

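To see what the typed loops are doing, here is the same two-pass structure written as ordinary, untyped Python (the name `loop_std_dev` is hypothetical, for illustration only). Without Cython's static types, every `a[i]` lookup and every `+=` operates on Python float objects; the `cdef` declarations in `cyOptStdDev` let Cython turn these same loops into plain C arithmetic on doubles, the same loops that the C function later in this post runs.

```python
import math

def loop_std_dev(a):
    # Hypothetical untyped equivalent of cyOptStdDev / std_dev.c
    n = len(a)
    m = 0.0
    for i in range(n):      # pass 1: sum the values, then divide to get the mean
        m += a[i]
    m /= n
    v = 0.0
    for i in range(n):      # pass 2: sum the squared deviations from the mean
        v += (a[i] - m) ** 2
    return math.sqrt(v / n)

print(loop_std_dev([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]))  # → 2.0
```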
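Finally, Cython calling C code (also in cyStdDev.pyx, reusing the cimports above). The C function walks a raw pointer, so the buffer is declared with `mode="c"` to guarantee a contiguous array:

    # File: cyStdDev.pyx
    cdef extern from "std_dev.h":
        double std_dev(double *arr, size_t siz)

    def cStdDev(ndarray[np.float64_t, ndim=1, mode="c"] a not None):
        # Pass the raw data pointer and element count straight to C
        return std_dev(<double*> a.data, a.size)
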
The C function is declared in std_dev.h:

    #include <stdlib.h>
    double std_dev(double *arr, size_t siz);

and implemented in std_dev.c:

    #include <math.h>
    #include "std_dev.h"

    double std_dev(double *arr, size_t siz)
    {
        double mean = 0.0;
        double sum_sq;
        double *pVal;
        double diff;

        /* First pass: mean */
        pVal = arr;
        for (size_t i = 0; i < siz; ++i, ++pVal) {
            mean += *pVal;
        }
        mean /= siz;

        /* Second pass: sum of squared deviations */
        pVal = arr;
        sum_sq = 0.0;
        for (size_t i = 0; i < siz; ++i, ++pVal) {
            diff = *pVal - mean;
            sum_sq += diff * diff;
        }
        return sqrt(sum_sq / siz);
    }

Each version is timed with Python's timeit module on one million values (`np.arange(1e6)` produces float64 values from 0 to 999999):

    # Pure Python
    python3 -m timeit -s "import StdDev; import numpy as np; a = [float(v) for v in range(1000000)]" "StdDev.pyStdDev(a)"
    # Numpy
    python3 -m timeit -s "import StdDev; import numpy as np; a = np.arange(1e6)" "StdDev.npStdDev(a)"
    # Cython - naive
    python3 -m timeit -s "import cyStdDev; import numpy as np; a = np.arange(1e6)" "cyStdDev.cyStdDev(a)"
    # Optimised Cython
    python3 -m timeit -s "import cyStdDev; import numpy as np; a = np.arange(1e6)" "cyStdDev.cyOptStdDev(a)"
    # Cython calling C
    python3 -m timeit -s "import cyStdDev; import numpy as np; a = np.arange(1e6)" "cyStdDev.cStdDev(a)"

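Before reading too much into the timings, it is worth checking that every variant returns the right number. The benchmark input 0, 1, ..., n - 1 has a closed-form population standard deviation, sqrt((n² - 1) / 12), which is about 288675.13 for n = 10⁶. The sketch below checks the pure-Python function against it; the same expected value could be compared with the NumPy and Cython results.

```python
import math

def pyStdDev(a):
    # Copy of the pure-Python function from StdDev.py
    mean = sum(a) / len(a)
    return math.sqrt(sum((x - mean) ** 2 for x in a) / len(a))

n = 1_000_000
a = [float(v) for v in range(n)]        # same input as the pure-Python timeit run
expected = math.sqrt((n * n - 1) / 12)  # closed form for the values 0, 1, ..., n - 1

print(math.isclose(pyStdDev(a), expected, rel_tol=1e-9))  # → True
```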
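Results:

| Method | Time (ms) | Speedup vs. Python | Speedup vs. NumPy |
|---|---|---|---|
| Pure Python | 183 | 1× | 0.03× |
| NumPy | 5.97 | 31× | 1× |
| Cython (naive) | 7.76 | 24× | 0.8× |
| Cython + NumPy (optimised) | 2.18 | 84× | 2.7× |
| Cython calling C | 2.22 | 82× | 2.7× |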
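The two ratio columns are simply the baseline time divided by each method's time. This snippet recomputes them from the measured timings (it does not re-run the benchmark), which is handy when adding a new row.

```python
# Timings in milliseconds, copied from the results table
timings = {
    "Pure Python": 183,
    "NumPy": 5.97,
    "Cython (naive)": 7.76,
    "Cython + NumPy (optimised)": 2.18,
    "Cython calling C": 2.22,
}

def speedup(baseline_ms, method_ms):
    # Speedup = baseline time / method time (values above 1 mean faster)
    return baseline_ms / method_ms

for name, t in timings.items():
    vs_py = speedup(timings["Pure Python"], t)
    vs_np = speedup(timings["NumPy"], t)
    print(f"{name:28s} {vs_py:6.1f}x {vs_np:5.2f}x")
```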
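Summary:
- NumPy is about 31× faster than pure Python.
- Even unoptimised, Cython comes surprisingly close to NumPy (about 0.8× its speed), mainly because its body just calls NumPy array methods, so NumPy does most of the work.
- Hand-written C is very fast, yet in this run it was marginally slower than the optimised Cython + NumPy version (2.22 ms vs 2.18 ms, under 2% apart): Cython still comes out on top.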
=============================================
qsy, 23 May 2019