问题:
给出一个数组$a[n]$,求第$k$小元素是什么。
解析:
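分治思想:将数组每五个元素划分为一组,对每组排序并求出组内中位数;再递归地求出这些中位数的中位数 $x$。以 $x$ 为基准划分数组,使不大于 $x$ 的元素位于其左侧、大于 $x$ 的元素位于其右侧。设划分后 $x$ 在数组中的排名(即 $x$ 及其左侧的元素个数)为 $num$,有三种情况。
① $num = k$,则 $x$ 就是要查询的数。
② $num > k$,则在 $x$ 左侧(不大于 $x$)的元素中递归查询第 $k$ 小。
③ $num < k$,则在 $x$ 右侧(大于 $x$)的元素中递归查询第 $k - num$ 小。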
设计(核心代码):
// Sort a[l..r] (both ends inclusive) in ascending order using insertion sort.
// It is only applied to runs of at most five elements, where it is cheap.
void insertsort(int a[], int l, int r)
{
    for (int i = l + 1; i <= r; ++i)
    {
        const int key = a[i];
        int j = i - 1;
        // Shift larger elements one slot to the right to open a hole for key.
        while (j >= l && key < a[j])
        {
            a[j + 1] = a[j];
            --j;
        }
        a[j + 1] = key;
    }
}

// Lomuto partition of a[l..r] around the VALUE `pivot`, which must occur
// somewhere in a[l..r].
// Returns the final index p of one copy of pivot; afterwards
// a[l..p-1] <= pivot and a[p+1..r] > pivot.
int partition(int a[], int l, int r, int pivot)
{
    // Move one copy of the pivot value to the end so the classic Lomuto
    // scheme (pivot stored at a[r]) applies.
    for (int j = l; j < r; ++j)
    {
        if (a[j] == pivot)
        {
            std::swap(a[j], a[r]);
            break;
        }
    }
    const int x = a[r];
    int i = l - 1;  // a[l..i] holds the elements seen so far that are <= x
    for (int j = l; j < r; ++j)
    {
        if (a[j] <= x)
        {
            ++i;
            std::swap(a[i], a[j]);
        }
    }
    std::swap(a[i + 1], a[r]);
    return i + 1;
}

// Return the k-th smallest value (k is 1-based) of a[l..r] using the
// median-of-medians (BFPRT) algorithm. a[l..r] is reordered in place.
// Precondition: l <= r and 1 <= k <= r - l + 1.
int select(int a[], int l, int r, int k)
{
    const int n = r - l + 1;
    if (n <= 5)
    {
        insertsort(a, l, r);
        return a[l + k - 1];
    }

    // Sort each group of five and move the median of group i to a[l + i],
    // so the group medians end up contiguous in a[l .. l + group - 1].
    // ceil(n / 5): the former (n + 5) / 5 produced an extra, empty group
    // whenever n was a multiple of 5.
    const int group = (n + 4) / 5;
    for (int i = 0; i < group; ++i)
    {
        const int left = l + 5 * i;
        const int right = std::min(left + 4, r);
        insertsort(a, left, right);
        std::swap(a[l + i], a[left + (right - left) / 2]);
    }
    // The median of the group medians is the pivot.
    const int pivot = select(a, l, l + group - 1, (group + 1) / 2);

    const int p = partition(a, l, r, pivot);

    // partition() leaves every other copy of pivot on the left of p. Gather
    // them into a contiguous run a[eq..p]; without this an input with many
    // duplicates (e.g. all elements equal) shrinks by only one element per
    // level, i.e. O(n^2) time and O(n) recursion depth.
    int eq = p;
    for (int j = p - 1; j >= l; --j)
    {
        if (a[j] == pivot)
        {
            std::swap(a[j], a[--eq]);
        }
    }

    // Now a[l..eq-1] < pivot, a[eq..p] == pivot, a[p+1..r] > pivot.
    const int less = eq - l;
    const int not_greater = p - l + 1;
    if (k <= less)
    {
        return select(a, l, eq - 1, k);
    }
    if (k <= not_greater)
    {
        return pivot;
    }
    return select(a, p + 1, r, k - not_greater);
}
分析:
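复杂度:最坏情况 $O(n)$。假设元素互不相同,中位数的中位数 $x$ 至少大于约 $3n/10$ 个元素,也至少小于约 $3n/10$ 个元素,因此每次递归只需处理至多约 $7n/10$ 个元素,得到 $T(n) \le T(\lceil n/5 \rceil) + T(7n/10 + 6) + O(n)$,解得 $T(n) = O(n)$。若存在大量重复元素,需把与 $x$ 相等的元素单独归为一段并整体跳过,否则会退化为 $O(n^2)$。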
源码:
https://github.com/Big-Kelly/Algorithm
// Worst-case linear-time selection (median of medians / BFPRT).
// Input:  n and k on the first line, followed by n integers.
// Output: the k-th smallest (1-based) of the n integers.
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

// Sort a[l..r] (both ends inclusive) in ascending order using insertion sort.
// It is only applied to runs of at most five elements, where it is cheap.
void insertsort(int a[], int l, int r)
{
    for (int i = l + 1; i <= r; ++i)
    {
        const int key = a[i];
        int j = i - 1;
        // Shift larger elements one slot to the right to open a hole for key.
        while (j >= l && key < a[j])
        {
            a[j + 1] = a[j];
            --j;
        }
        a[j + 1] = key;
    }
}

// Lomuto partition of a[l..r] around the VALUE `pivot`, which must occur
// somewhere in a[l..r].
// Returns the final index p of one copy of pivot; afterwards
// a[l..p-1] <= pivot and a[p+1..r] > pivot.
int partition(int a[], int l, int r, int pivot)
{
    // Move one copy of the pivot value to the end so the classic Lomuto
    // scheme (pivot stored at a[r]) applies.
    for (int j = l; j < r; ++j)
    {
        if (a[j] == pivot)
        {
            std::swap(a[j], a[r]);
            break;
        }
    }
    const int x = a[r];
    int i = l - 1;  // a[l..i] holds the elements seen so far that are <= x
    for (int j = l; j < r; ++j)
    {
        if (a[j] <= x)
        {
            ++i;
            std::swap(a[i], a[j]);
        }
    }
    std::swap(a[i + 1], a[r]);
    return i + 1;
}

// Return the k-th smallest value (k is 1-based) of a[l..r] using the
// median-of-medians (BFPRT) algorithm. a[l..r] is reordered in place.
// Precondition: l <= r and 1 <= k <= r - l + 1.
int select(int a[], int l, int r, int k)
{
    const int n = r - l + 1;
    if (n <= 5)
    {
        insertsort(a, l, r);
        return a[l + k - 1];
    }

    // Sort each group of five and move the median of group i to a[l + i],
    // so the group medians end up contiguous in a[l .. l + group - 1].
    // ceil(n / 5): the former (n + 5) / 5 produced an extra, empty group
    // whenever n was a multiple of 5.
    const int group = (n + 4) / 5;
    for (int i = 0; i < group; ++i)
    {
        const int left = l + 5 * i;
        const int right = std::min(left + 4, r);
        insertsort(a, left, right);
        std::swap(a[l + i], a[left + (right - left) / 2]);
    }
    // The median of the group medians is the pivot.
    const int pivot = select(a, l, l + group - 1, (group + 1) / 2);

    const int p = partition(a, l, r, pivot);

    // partition() leaves every other copy of pivot on the left of p. Gather
    // them into a contiguous run a[eq..p]; without this an input with many
    // duplicates (e.g. all elements equal) shrinks by only one element per
    // level, i.e. O(n^2) time and O(n) recursion depth.
    int eq = p;
    for (int j = p - 1; j >= l; --j)
    {
        if (a[j] == pivot)
        {
            std::swap(a[j], a[--eq]);
        }
    }

    // Now a[l..eq-1] < pivot, a[eq..p] == pivot, a[p+1..r] > pivot.
    const int less = eq - l;
    const int not_greater = p - l + 1;
    if (k <= less)
    {
        return select(a, l, eq - 1, k);
    }
    if (k <= not_greater)
    {
        return pivot;
    }
    return select(a, p + 1, r, k - not_greater);
}

int main()
{
    int n = 0;
    int k = 0;
    if (std::scanf("%d %d", &n, &k) != 2 || n < 1 || k < 1 || k > n)
    {
        std::fprintf(stderr, "expected n >= 1 and 1 <= k <= n\n");
        return 1;
    }

    // Sized from the input instead of a fixed global array; index 0 is
    // unused so the data stays 1-based as in the original listing.
    std::vector<int> a(static_cast<std::size_t>(n) + 1);
    for (int i = 1; i <= n; ++i)
    {
        if (std::scanf("%d", &a[i]) != 1)
        {
            std::fprintf(stderr, "expected %d integers\n", n);
            return 1;
        }
    }

    std::printf("%d\n", select(a.data(), 1, n, k));
    return 0;
}