问题描述
输入一个整形数组和K,输出数组中前K大的元素们。
解决思路
思路1:排序
如果用快排,平均时间复杂度为O(nlogn),最坏时间复杂度为O(n^2);空间复杂度为O(logn)~O(n);
如果用堆排,时间复杂度为O(nlogn),空间复杂度为O(1).
注意:
Java中Arrays.sort()方法默认实现为归并排序,时间O(nlogn),空间O(n)。
思路2:借用快排的partition函数
较思路1的改进在于,不一定要完全将整个数组进行排序,快排中的partition函数能够保证partition后的元素位置之前的元素均大于等于(或小于等于)该指向元素。
平均时间复杂度为O(n),最坏时间复杂度和快排的一样O(n^2)。
思路3:大数据下的堆排
如果场景为数据量很大,或者甚至是无穷量的数据时,此时可借用堆排的思想。
具体做法为,如果是输出前K大,那么需要维护一个大小为K的最小堆,之后的元素与堆顶元素进行比较,如果更大则进入堆中,再调整堆。
时间复杂度为O(n*logk + k),空间复杂度为O(k)。
程序
public class TopK { // sort public List<Integer> getTopKBySort(int[] nums, int k) { List<Integer> res = new ArrayList<Integer>(); if (nums == null || nums.length == 0 || nums.length < k || k <= 0) { return res; } Arrays.sort(nums); for (int i = 0; i < k; i++) { res.add(nums[i]); } return res; } // partition public List<Integer> getTopKByPartition(int[] nums, int k) { List<Integer> res = new ArrayList<Integer>(); if (nums == null || nums.length == 0 || nums.length < k || k <= 0) { return res; } int part = partition(nums, 0, nums.length - 1); while (true) { if (part == k - 1) { for (int i = 0; i < k; i++) { res.add(nums[i]); } break; } else if (part < k - 1) { part = partition(nums, part + 1, nums.length - 1); } else { part = partition(nums, 0, part - 1); } } return res; } private int partition(int[] nums, int begin, int end) { int low = begin - 1, high = end; int pivot = nums[end]; while (true) { while (low < high && nums[++low] >= pivot) { ; } while (low < high && nums[--high] <= pivot) { ; } if (low >= high) { break; } swap(nums, low, high); } swap(nums, low, end); return low; } private void swap(int[] nums, int low, int high) { int tmp = nums[low]; nums[low] = nums[high]; nums[high] = tmp; } // heap public int[] getTopKByHeap(int[] nums, int k) { if (nums == null || nums.length == 0 || nums.length < k || k <= 0) { return null; } int[] res = new int[k]; for (int i = 0; i < nums.length; i++) { if (i < k) { res[i] = nums[i]; } else if (i == k) { buildMinHeap(res); } else { if (nums[i] > res[0]) { res[0] = nums[i]; fixMaxDown(res, 0); } } } return res; } private void fixMaxDown(int[] heap, int i) { int tmp = heap[i]; int j = 2*i +1; while (j < heap.length) { while (j+1 < heap.length && heap[j+1] < heap[j]) { ++j; } if (tmp<heap[j]) { break; } heap[i] = heap[j]; i = j; j = 2*i + 1; } heap[i] = tmp; } private void buildMinHeap(int[] heap) { for (int i = heap.length/2 - 1; i >= 0; i--) { fixMaxDown(heap, i); } } }