这一章就讲解 numpy.where 函数。
他是三元表达式 的向量化版本。
三元表达式: x if condition else y
import numpy as np
xarr = np.array([1.1, 1.2, 1.3, 1.4, 1.5])
yarr = np.array([2.1, 2.2, 2.3, 2.4, 2.5])
cond = np.array([True, False, True, True, False])
在 cond 中的元素为 True 时,我们取 xarr 中对应的元素值,反之取 yarr 中的元素,可以通过列表推导式来完成,如下
result = [(x if c else y) for x,y,c in zip(xarr, yarr, cond)]
print result
结果
[1.1, 2.2, 1.3, 1.4, 2.5]
这样会产生多个问题,首先如果数组很大的话,速度会很慢,因为所有的工作都是通过解释器解释 Python 代码完成的,其次,但数组是多维时,就无法奏效了。而使用 np.where 时,就可以非常简单的完成。
result = np.where(cond, xarr, yarr)
print(result)
结果
[1.1, 2.2, 1.3, 1.4, 2.5]
where 的第二个和第三个参数并不需要是数组,它们可以是标量。where在数据分析中的一个典型用法是根据一个数组来生成一个新的数组。
比如:假设我有一个随机生成的正态分布的矩阵数据,并且你想将其中的正数都变成2 ,所有负数都变成 -2 ,使用 where就很容易实现
arr = np.random.randn(4,4)
print(arr)
print('-----------')
print(np.where(arr>0, 2, -2))
print('---------')
print(np.where(arr > 0, 2 , arr))
结果
[[-1.36253267 0.7440669 0.64946862 0.36891392]
[-0.01551911 0.01003852 2.32228195 -0.17199506]
[ 0.46897204 -0.15731851 -0.53239796 -0.56446061]
[-0.55566614 -0.27882886 -0.12773451 -0.15585518]]
-----------
[[-2 2 2 2]
[-2 2 2 -2]
[ 2 -2 -2 -2]
[-2 -2 -2 -2]]
---------
[[-1.36253267 2. 2. 2. ]
[-0.01551911 2. 2. -0.17199506]
[ 2. -0.15731851 -0.53239796 -0.56446061]
[-0.55566614 -0.27882886 -0.12773451 -0.15585518]]
正如看到的,第三个参数可以写 原本的数组,就相当于把小于0的元素用原本的数组的对应位置的值对换,也就是说不做处理。
总结:
传递给 np.where 的数组既可以是同等大小的数组,也可以是标量