#include <iostream>
#include <cstdlib>
#include <cmath>
#include <algorithm>

using namespace std;

// Median of two sorted arrays by bisecting the *value* range rather than an
// index range: the k-th smallest element is the smallest integer v such that
// at least k elements of the union are <= v. Each probe counts with
// upper_bound on both arrays, and the integer range is exhausted after at
// most ~32 halvings, so no floating-point tolerance is needed.
class Solution {
public:
    // Returns the median of the union of sorted arrays A[0..m) and B[0..n).
    // A NULL pointer or a non-positive length is treated as an empty array;
    // if both arrays are empty, 0 is returned.
    double findMedianSortedArrays(int A[], int m, int B[], int n) {
        const int len_a = (A == NULL || m < 1) ? 0 : m;
        const int len_b = (B == NULL || n < 1) ? 0 : n;
        const int total = len_a + len_b;
        if (total == 0) {
            return 0;
        }
        // 1-based ranks of the two middle elements (identical when total is odd).
        const int lower = kthSmallest(A, len_a, B, len_b, (total + 1) / 2);
        const int upper = kthSmallest(A, len_a, B, len_b, total / 2 + 1);
        // Widen before adding so two large ints cannot overflow.
        return (static_cast<double>(lower) + upper) / 2.0;
    }

private:
    // Number of elements <= v in A[0..m) plus B[0..n).
    static long long countNotGreater(const int* A, int m, const int* B, int n, long long v) {
        long long count = 0;
        if (m > 0) count += upper_bound(A, A + m, v) - A;
        if (n > 0) count += upper_bound(B, B + n, v) - B;
        return count;
    }

    // k-th smallest (1-based) element of A[0..m) U B[0..n).
    // Requires 1 <= k <= m + n, and each array with positive length non-NULL.
    static int kthSmallest(const int* A, int m, const int* B, int n, int k) {
        long long lo = 0;  // inclusive candidate range [lo, hi]
        long long hi = 0;
        if (m == 0) {
            lo = B[0];
            hi = B[n - 1];
        } else if (n == 0) {
            lo = A[0];
            hi = A[m - 1];
        } else {
            lo = min(A[0], B[0]);
            hi = max(A[m - 1], B[n - 1]);
        }
        // Invariant: count(<= hi) >= k and count(<= lo - 1) < k.
        while (lo < hi) {
            const long long mid = lo + (hi - lo) / 2;  // 64-bit: cannot overflow
            if (countNotGreater(A, m, B, n, mid) >= k) {
                hi = mid;
            } else {
                lo = mid + 1;
            }
        }
        // The count only grows at element values, so lo is an element.
        return static_cast<int>(lo);
    }
};

int main() {
    Solution s;
    int A[] = {1, 1};
    int B[] = {1, 2};
    int m = sizeof(A) / sizeof(A[0]);
    int n = sizeof(B) / sizeof(B[0]);

    cout << s.findMedianSortedArrays(A, m, B, n) << endl;
    system("pause");  // Windows-only: keeps the console window open
    return 0;
}
写得好乱啊。这个本质上还是二分搜索,只不过决定选择前半部分还是后半部分的评价标准变了:由原来的与一个确定常数比较,变为两个变量之间的比较(lh 与 half 之间的数量关系);搜索空间也由一个数组变为一个数值区间(其实两者都可以看作解的值域)。运行时间 230ms+。
题目中提到"The overall run time complexity should be O(log (m+n)).",其实 log 前面还有常数:由于数据是整数,经过 32 次二分搜索可以使数值空间降到 1 以内,再过 4 次可以降到 0.1 以内;而且每一次二分还要在两个数组上各做一次二分查找来计数。
换用 O(m+n) 的简单归并解法,运行时间感觉差不多,不知为何(可能是测试数据规模较小,常数开销占了主导)。
// O(m + n) merge walk: step through the virtual merged array until the lower
// middle element is reached, then (for an even total) peek at the next one.
class Solution {
public:
    // Returns the median of the union of sorted arrays A[0..m) and B[0..n).
    // A NULL pointer or a non-positive length counts as an empty array; two
    // empty arrays yield 0.
    double findMedianSortedArrays(int A[], int m, int B[], int n) {
        const int na = (A == NULL || m < 1) ? 0 : m;
        const int nb = (B == NULL || n < 1) ? 0 : n;
        const int total = na + nb;
        if (total == 0) {
            return 0;
        }

        int ia = 0;
        int ib = 0;
        int val = 0;
        const int lower_mid = (total - 1) / 2;  // 0-based index of the lower median
        for (int it = 0; it <= lower_mid; ++it) {
            if (takeFromA(A, ia, na, B, ib, nb)) {
                val = A[ia++];
            } else {
                val = B[ib++];
            }
        }
        if (total & 1) {
            return val;
        }
        // Even total: at least one element remains, so this read is in range.
        const int val2 = takeFromA(A, ia, na, B, ib, nb) ? A[ia] : B[ib];
        // Widen before adding: int + int could overflow.
        return (static_cast<double>(val) + val2) / 2.0;
    }

private:
    // True when the next merged element should come from A: A still has
    // elements and either B is exhausted or A's head is strictly smaller
    // (ties go to B).
    static bool takeFromA(const int* A, int ia, int na, const int* B, int ib, int nb) {
        if (ia >= na) return false;
        if (ib >= nb) return true;
        return A[ia] < B[ib];
    }
};
在discuss里找到一份log(m+n)的代码:
class Solution {
public:
    // Median of the union of sorted A[0..m) and B[0..n), computed with one
    // k-th-smallest query for odd lengths and two for even lengths.
    double findMedianSortedArrays(int A[], int m, int B[], int n) {
        int length = m + n;
        int upperRank = length / 2 + 1;
        if (length % 2 != 0) {
            return findkth(A, m, B, n, upperRank);
        }
        double below = double(findkth(A, m, B, n, length / 2));
        return (below + findkth(A, m, B, n, upperRank)) / 2;
    }

    // Returns the k-th smallest (1-based) element of A[0..m) U B[0..n).
    // A is kept as the shorter array; each call discards pa or pb elements.
    int findkth(int A[], int m, int B[], int n, int k) {
        if (m > n) {
            return findkth(B, n, A, m, k);  // make A the shorter array
        }
        if (m == 0) {
            return B[k - 1];
        }
        if (k == 1) {
            if (A[0] < B[0]) {
                return A[0];
            }
            return B[0];
        }
        int pa = k / 2;
        if (pa > m) {
            pa = m;
        }
        int pb = k - pa;  // pa + pb == k
        int lastA = A[pa - 1];
        int lastB = B[pb - 1];
        if (lastA < lastB) {
            // Drop A[0..pa); this call also narrows B to its first pb elements.
            return findkth(A + pa, m - pa, B, pb, k - pa);
        }
        if (lastA > lastB) {
            // Symmetric case: drop B[0..pb) and narrow A to its first pa elements.
            return findkth(A, pa, B + pb, n - pb, k - pb);
        }
        return lastA;  // equal boundary values: this is the k-th smallest
    }
};
花点时间理解:
下面先不考虑 k/2 >= m(m 为数组 A 的长度)以及 K=1 的情况(K=1 时比较两数组首元素即可得出),数组下标从 0 开始。取第 K 个数的算法:首先取 pa=k/2, pb=k-k/2;这样 {A[0], A[1]...A[pa-1]} 的元素数目加上 {B[0], B[1]...B[pb-1]} 的元素数目恰好等于 k。此时如果:
1. A[pa-1] = B[pb-1],那么很容易知道A[pa-1]或者说B[pb-1]就是第K个数。因为数组是已排序的,且|{A[0], A[1]...A[pa-1]}| + |{B[0], B[1]...B[pb-1]}| = K
2. A[pa-1] < B[pb-1],那么可以认为第K个数肯定不在数组A的[0, pa-1]这个区间内。用反证法可以证明:
假设第K个数存在于A[0...pa-1]中,设其为X,则根据第K个数的含义,X前面必然有K-1个小于等于X的数。由于X位于A[0...pa-1]中,数组A中排在X前面的数最多只有(当X就是A[pa-1]时)|{A[0], A[1]...A[pa-1]}| - 1 = k/2 - 1 < K-1 个。剩下的数需要从B数组中取,至少需要 K - 1 - (K/2 - 1) = K - K/2 = pb 个,即 B[0...pb-1] 都必须小于等于X。但由于 X ≤ A[pa-1] < B[pb-1],B数组中的第 K-K/2 个数(即B[pb-1])要比X大,产生矛盾,故假设不成立。所以第K个数肯定不在数组A的[0, pa-1]这个区间内,此时只需要在剩下的区间内搜索:寻找第K小的元素变为寻找第(K-pa)小的元素(因为已经排除了数组A中前pa个元素)。
3. A[pa-1] > B[pb-1],这种情况是第二种情况的对称情况,即可以排除{B[0], B[1]...B[pb-1]}这个搜索区间,并继续寻找第(K-pb)小的元素。
再来一次:
class Solution {
public:
    // Median of the union of two sorted vectors via k-th-smallest queries.
    // Returns 0 when both vectors are empty.
    double findMedianSortedArrays(std::vector<int>& nums1, std::vector<int>& nums2) {
        int len1 = static_cast<int>(nums1.size());
        int len2 = static_cast<int>(nums2.size());
        int total = len1 + len2;
        if (total == 0) {
            return 0;  // no median; findK would otherwise read b[-1]
        }
        // &v[0] on an empty vector is undefined behaviour, so go through
        // firstElement(), which yields NULL for an empty vector.
        const int* a = firstElement(nums1);
        const int* b = firstElement(nums2);
        if (total & 0x1) {
            return findK(a, b, len1, len2, total / 2 + 1);
        }
        double lo = findK(a, b, len1, len2, total / 2);
        double hi = findK(a, b, len1, len2, total / 2 + 1);
        return (lo + hi) / 2;
    }

    // k-th smallest (1-based) element of a[0..na) U b[0..nb).
    // Requires 1 <= k <= na + nb; a pointer may be NULL only if its length is 0.
    int findK(const int* a, const int* b, int na, int nb, int k) {
        if (nb < na) {
            return findK(b, a, nb, na, k);  // keep a as the shorter array
        }
        if (na == 0) {
            return b[k - 1];
        }
        if (k == 1) {
            return a[0] > b[0] ? b[0] : a[0];
        }
        int pa = k / 2 < na ? k / 2 : na;
        int pb = k - pa;  // pa + pb == k
        if (a[pa - 1] == b[pb - 1]) {
            return a[pa - 1];
        } else if (a[pa - 1] < b[pb - 1]) {
            // The k-th is not among a[0..pa): discard them.
            return findK(a + pa, b, na - pa, nb, k - pa);
        } else {
            // Symmetric case: discard b[0..pb).
            return findK(a, b + pb, na, nb - pb, k - pb);
        }
    }

private:
    // Pointer to the first element, or NULL for an empty vector.
    static const int* firstElement(const std::vector<int>& v) {
        if (v.empty()) {
            return NULL;
        }
        return &v[0];
    }
};
不用指针,写起来真是麻烦了好多:
class Solution {
public:
    // Median of the union of two sorted vectors, using index windows instead
    // of pointers. Returns 0 when both vectors are empty.
    double findMedianSortedArrays(std::vector<int>& nums1, std::vector<int>& nums2) {
        int n1 = static_cast<int>(nums1.size());
        int n2 = static_cast<int>(nums2.size());
        int total = n1 + n2;
        if (total == 0) {
            return 0;  // no median; findk would otherwise index out of range
        }
        if (total & 0x1) {
            return findk(nums1, 0, n1, nums2, 0, n2, total / 2);
        }
        return (findk(nums1, 0, n1, nums2, 0, n2, total / 2) +
                findk(nums1, 0, n1, nums2, 0, n2, total / 2 - 1)) / 2;
    }

    // k is 0-based: returns the (k+1)-th smallest element of the windows
    // n1[s1, e1) and n2[s2, e2). Requires 0 <= k < (e1 - s1) + (e2 - s2).
    double findk(std::vector<int>& n1, int s1, int e1, std::vector<int>& n2, int s2, int e2, int k) {
        int l1 = e1 - s1;
        int l2 = e2 - s2;
        if (l2 < l1) {
            return findk(n2, s2, e2, n1, s1, e1, k);  // keep window 1 the shorter
        }
        if (l1 == 0) {
            return n2[s2 + k];
        }
        if (k == 0) {
            return n1[s1] > n2[s2] ? n2[s2] : n1[s1];
        }
        int pa = (k + 1) / 2 > l1 ? l1 : (k + 1) / 2;
        int pb = k + 1 - pa;  // pa + pb == k + 1 elements examined
        int lastA = n1[s1 + pa - 1];
        int lastB = n2[s2 + pb - 1];
        if (lastA == lastB) {
            // Index relative to the window start s1, not to the vector start.
            return lastA;
        }
        if (lastA > lastB) {
            // The answer is not among n2[s2, s2 + pb): discard them.
            return findk(n1, s1, e1, n2, s2 + pb, e2, k - pb);
        }
        // Symmetric case: discard n1[s1, s1 + pa).
        return findk(n1, s1 + pa, e1, n2, s2, e2, k - pa);
    }
};