1. 背景
笔试时,遇到一个算法题:差不多是 在n个不同的数中随机取出不重复的m个数。洗牌算法是将原来的数组进行打散,使原数组的某个数在打散后的数组中的每个位置上等概率的出现,刚好可以解决该问题。
2. 洗牌算法
由抽牌、换牌和插牌衍生出三种洗牌算法,其中抽牌和换牌分别对应Fisher-Yates Shuffle和Knuth-Durstenfeld Shhuffle算法。
2.1 Fisher-Yates Shuffle算法
最早提出这个洗牌方法的是 Ronald A. Fisher 和 Frank Yates,即 Fisher–Yates Shuffle,其基本思想就是从原始数组中随机取一个之前没取过的数字到新的数组中,具体如下:
1. 初始化原始数组和新数组,原始数组长度为n(已知);
2. 从还没处理的数组(假如还剩k个)中,随机产生一个[0, k)之间的数字p(假设数组从0开始);
3. 从剩下的k个数中把第p个数取出;
4. 重复步骤2和3直到数字全部取完;
5. 从步骤3取出的数字序列便是一个打乱了的数列。
下面证明其随机性,即每个元素被放置在新数组中的第i个位置是1/n(假设数组大小是n)。
证明:一个元素m被放入第i个位置的概率P = 前i-1个位置选择元素时没有选中m的概率 * 第i个位置选中m的概率,即
#define N 10 #define M 5 void Fisher_Yates_Shuffle(vector<int>& arr,vector<int>& res) { srand((unsigned)time(NULL)); int k; for (int i=0;i<M;++i) { k=rand()%arr.size(); res.push_back(arr[k]); arr.erase(arr.begin()+k); } }
时间复杂度为O(n*n),空间复杂度为O(n).
2.2 Knuth-Durstenfeld Shuffle
Knuth 和 Durstenfeld 在Fisher 等人的基础上对算法进行了改进,在原始数组上对数字进行交互,省去了额外O(n)的空间。该算法的基本思想和 Fisher 类似,每次从未处理的数据中随机取出一个数字,然后把该数字放在数组的尾部,即数组尾部存放的是已经处理过的数字。
算法步骤为:
1. 建立一个数组大小为 n 的数组 arr,分别存放 1 到 n 的数值;
2. 生成一个从 0 到 n - 1 的随机数 x;
3. 输出 arr 下标为 x 的数值,即为第一个随机数;
4. 将 arr 的尾元素和下标为 x 的元素互换;
5. 同2,生成一个从 0 到 n - 2 的随机数 x;
6. 输出 arr 下标为 x 的数值,为第二个随机数;
7. 将 arr 的倒数第二个元素和下标为 x 的元素互换;
……
如上,直到输出 m 个数为止
该算法是经典洗牌算法。它的proof如下:
对于arr[i],洗牌后在第n-1个位置的概率是1/n(第一次交换的随机数为i)
在n-2个位置概率是[(n-1)/n] * [1/(n-1)] = 1/n,(第一次交换的随机数不为i,第二次为arr[i]所在的位置(注意,若i=n-1,第一交换arr[n-1]会被换到一个随机的位置))
在第n-k个位置的概率是[(n-1)/n] * [(n-2)/(n-1)] *...* [(n-k+1)/(n-k+2)] *[1/(n-k+1)] = 1/n
(第一个随机数不要为i,第二次不为arr[i]所在的位置(随着交换有可能会变)……第n-k次为arr[i]所在的位置).
void Knuth_Durstenfeld_Shuffle(vector<int>&arr) { for (int i=arr.size()-1;i>=0;--i) { srand((unsigned)time(NULL)); swap(arr[rand()%(i+1)],arr[i]); } }
时间复杂度为O(n),空间复杂度为O(1),缺点必须知道数组长度n.
原始数组被修改了,这是一个原地打乱顺序的算法,算法时间复杂度也从Fisher算法的 O(n2)提升到了O(n)。由于是从后往前扫描,无法处理不知道长度或动态增长的数组。
2.3 Inside-Out Algorithm
Knuth-Durstenfeld Shuffle 是一个内部打乱的算法,算法完成后原始数据被直接打乱,尽管这个方法可以节省空间,但在有些应用中可能需要保留原始数据,所以需要另外开辟一个数组来存储生成的新序列。
Inside-Out Algorithm 算法的基本思思是从前向后扫描数据,把位置i的数据随机插入到前i个(包括第i个)位置中(假设为k),这个操作是在新数组中进行,然后把原始数据中位置k的数字替换新数组位置i的数字。其实效果相当于新数组中位置k和位置i的数字进行交互。
如果知道arr的lengh的话,可以改为for循环,由于是从前往后遍历,所以可以应对arr[]数目未知的情况,或者arr[]是一个动态增加的情况。
证明如下:
原数组的第 i 个元素(随机到的数)在新数组的前 i 个位置的概率都是:(1/i) * [i/(i+1)] * [(i+1)/(i+2)] *...* [(n-1)/n] = 1/n,(即第i次刚好随机放到了该位置,在后面的n-i 次选择中该数字不被选中)。
原数组的第 i 个元素(随机到的数)在新数组的 i+1 (包括i + 1)以后的位置(假设是第k个位置)的概率是:(1/k) * [k/(k+1)] * [(k+1)/(k+2)] *...* [(n-1)/n] = 1/n(即第k次刚好随机放到了该位置,在后面的n-k次选择中该数字不被选中)。 void Inside_Out_Shuffle(const vector<int>&arr,vector<int>& res)
{
res.assign(arr.size(),0);
copy(arr.begin(),arr.end(),res.begin());
int k;
for (int i=0;i<arr.size();++i)
{
srand((unsigned)time(NULL));
k=rand()%(i+1);
res[i]=res[k];
res[k]=arr[i];
}
}
时间复杂度为O(n),空间复杂度为O(n).
2.4 蓄水池抽样
从N个元素中随机等概率取出k个元素,N长度未知。它能够在o(n)时间内对n个数据进行等概率随机抽取。如果数据集合的量特别大或者还在增长(相当于未知数据集合总量),该算法依然可以等概率抽样.
伪代码:
Init : a reservoir with the size: k for i= k+1 to N M=random(1, i); if( M < k) SWAP the Mth value and ith value end for
上述伪代码的意思是:先选中第1到k个元素,作为被选中的元素。然后依次对第k+1至第N个元素做如下操作:
每个元素都有k/x的概率被选中,然后等概率的(1/k)替换掉被选中的元素。其中x是元素的序号。
proof:
每次都是以 k/i 的概率来选择
例: k=1000的话, 从1001开始作选择,1001被选中的概率是1000/1001,1002被选中的概率是1000/1002,与我们直觉是相符的。
接下来证明:
假设当前是i+1, 按照我们的规定,i+1这个元素被选中的概率是k/i+1,也即第 i+1 这个元素在蓄水池中出现的概率是k/i+1
此时考虑前i个元素,如果前i个元素出现在蓄水池中的概率都是k/i+1的话,说明我们的算法是没有问题的。
对这个问题可以用归纳法来证明:k < i <=N
1.当i=k+1的时候,蓄水池的容量为k,第k+1个元素被选择的概率明显为k/(k+1), 此时前k个元素出现在蓄水池的概率为 k/(k+1), 很明显结论成立。
2.假设当 j=i 的时候结论成立,此时以 k/i 的概率来选择第i个元素,前i-1个元素出现在蓄水池的概率都为k/i。
证明当j=i+1的情况:
即需要证明当以 k/i+1 的概率来选择第i+1个元素的时候,此时任一前i个元素出现在蓄水池的概率都为k/(i+1).
前i个元素出现在蓄水池的概率有2部分组成, ①在第i+1次选择前得出现在蓄水池中,②得保证第i+1次选择的时候不被替换掉
①.由2知道在第i+1次选择前,任一前i个元素出现在蓄水池的概率都为k/i
②.考虑被替换的概率:
首先要被替换得第 i+1 个元素被选中(不然不用替换了)概率为 k/i+1,其次是因为随机替换的池子中k个元素中任意一个,所以不幸被替换的概率是 1/k,故
前i个元素(池中元素)中任一被替换的概率 = k/(i+1) * 1/k = 1/i+1
则(池中元素中)没有被替换的概率为: 1 - 1/(i+1) = i/i+1
综合① ②,通过乘法规则
得到前i个元素出现在蓄水池的概率为 k/i * i/(i+1) = k/i+1
故证明成立
如果m被选中,则随机替换水库中的一个对象。最终每个对象被选中的概率均为k/n,证明如下:
证明:第m个对象被选中的概率=选择m的概率*(其后元素不被选择的概率+其后元素被选择的概率*不替换第m个对象的概率),即
void Reservoir_Sampling(vector<int>& arr) { int k; for (int i=M;i<arr.size();++i) { srand((unsigned)time(NULL)); k=rand()%(i+1); if (k<M) swap(arr[k],arr[i]); } }
因此,蓄水池抽样因为不需知道n的长度,可用于机器学习的数据集的划分,等概率随机抽样分为测试集和训练集。
Leetcode 例题:
3.1. Linked List Random Node
Given a singly linked list, return a random node's value from the linked list. Each node must have the same probability of being chosen.
Follow up:
What if the linked list is extremely large and its length is unknown to you? Could you solve this efficiently without using extra space?
Example:
// Init a singly linked list [1,2,3]. ListNode head = new ListNode(1); head.next = new ListNode(2); head.next.next = new ListNode(3); Solution solution = new Solution(head); // getRandom() should return either 1, 2, or 3 randomly. Each element should have equal probability of returning. solution.getRandom();
利用蓄水池采样原理,无需事先计算list长度即可求解,具体代码如下:
public class Solution { Random r; ListNode head; public Solution(ListNode head) { this.r = new Random(); this.head = head; } public int getRandom() { int count = 1; ListNode nodeVal = head; ListNode curr = head; while (curr != null) { if (r.nextInt(count++) == 0) { nodeVal = curr; } curr = curr.next; } return nodeVal.val; } }
3.2. Random Pick Index
Given an array of integers with possible duplicates, randomly output the index of a given target number. You can assume that the given target number must exist in the array.
Note:
The array size can be very large. Solution that uses too much extra space will not pass the judge.
Example:
int[] nums = new int[] {1,2,3,3,3}; Solution solution = new Solution(nums); // pick(3) should return either index 2, 3, or 4 randomly. Each index should have equal probability of returning. solution.pick(3); // pick(1) should return 0. Since in the array only nums[0] is equal to 1. solution.pick(1);
若采用常规算法,我们需要用HashMap把对应的数值及所处的位置一一关联,这样则违背了题目的空间限制。使用蓄水池采样则避免了这个问题。具体代码如下:
public class Solution { int[] nums; public Solution(int[] nums) { this.nums = nums; } public int pick(int target) { int index = -1; int count = 1; Random random = new Random(); for (int i = 0; i < nums.length; i++) { if (nums[i] == target) { if (random.nextInt(count++) == 0) { index = i; } } } return index; } }
首先我们取到第一个数(暂时取的最后要不要还不一定呢),然后对第二个数以1/2的概率来确定是否 用第二个数来替换他,然后对第二个数以1/3的概率来确定是否用第三个数来替换他。。。。一直这样下去直到第n个数。经过上面的这个过程我们发现每个数取到的概率都变成了(1/n)。证明如下:
总结起来就是一句话每个数取到的概率等于取到该数且取不到该数后面所有数的概率。如:取到第10个数的概率等于取到第十个数且取不到第11到第n个数的概率现在我们回到较复杂的情况,也就是如何在一个N个数(开始不知道N是几)中随机取M个数。其实思想是一样的,就是先取出前M个,然后对后面的开始每个以(k/(i))的概率进行替换,这样我们得到的就是所要的结果,证明如下:
import random import copy def reservoirSampling(seq, k): localSeq = copy.deepcopy(seq) N = len(localSeq) for i in xrange(k, N): M = int(random.uniform(0, i)) if M < k : temp = copy.deepcopy(localSeq[M]) localSeq[M] = copy.deepcopy(localSeq[i]) localSeq[i] = temp return localSeq[0:k] def main(): a = [4,5,6,3,4,7,7,4,3,3,2,4,5,5,6,9,5,4,3,45,3,23,44,55,33,5,8] k = 5 print reservoirSampling(a, k) if __name__ == '__main__': main()
代码总结:
package Random; import java.util.*; public class RandomMethods { /**问题一*/ // 给你一个数组,设计一个既高效又公平的方法随机打乱这个数组(此题和洗牌算法的思想一致) void swap(int[] arr, int i, int j ){ int t = arr[i]; arr[i] = arr[j]; arr[j] = t; } void shuffle_dfs(int[] arr, int n){ if(n <= 1 ){ return; } Random random = new Random(); int t = random.nextInt(n); swap(arr,n-1, t); shuffle_dfs(arr, n-1); } void shuffle(int[] arr , int n){ while(n>1){ Random random = new Random(); int t = random.nextInt(n); swap(arr, n-1, t); n--; } } /**问题2**/ // n已知 //快速生成10亿个不重复的18位随机数的算法(从n个数中生成m个不重复的随机数) //假设从-n这n个数中生成m个不重复的数,且n小于int的表示范围 //总体思想是一开始每个数被选中的概率是m/n,于是随机一个数模n如果余数小于m则输出该数,同时m减 //否则继续扫描,以后的每个数被选中的概率都是m/(n-i) /*遍历第1个数字时有m/n的概率进行选择,如果选择了第1个数字, 则第2个数字被选择的概率调整为(m-1)/(n-1),如果没选择第1个数字, 则第2个数字被选择的概率为m/(n-1)。即遍历到第i个数字的时候, 如果此时已经选择了k个,则以(m-k)/(n-i+1)的概率决定是否要选择当前的第i个数字。 这样可以保证每次都能够保证在剩下的数字中能选择适当的数使得总体选择的数字是m个。 比如,如果前面已经随机了m个,则后面随机的概率就变为0。如果前面一直都没随机到数字, 则后面随机到的概率就会接近1。最终得到的结果始终精确地是m个数字。 */ void random_generate(int n,int m){ int i =1; while(n-i > m ){ Random random = new Random(); int t = random.nextInt(n-i); if(t < m){ System.out.println(i); m -- ; } i++; } while( ++i <= n) System.out.println(i); } void swap(ArrayList<Integer> arr, int i, int j ){ int t = arr.get(i); arr.set(i, arr.get(j)); arr.set(j,t); } // 蓄水池抽样算法 void Reservoir_Sampling(ArrayList<Integer> arr ,int K) { for (int i=K+1;i<arr.size();++i) { Random random = new Random(); int M = random.nextInt(i+1); if (M < K) swap(arr,M, i); } for(int i=0; i<K;i++){ System.out.print(arr.get(i)+" "); } System.out.println(); } }