寻找数组中第K大的元素

758 阅读2分钟

问题

Find the kth largest element in an unsorted array. Note that it is the kth largest element in the sorted order, not the kth distinct element.

Example 1:

Input: [3,2,1,5,6,4] and k = 2
Output: 5
Example 2:

Input: [3,2,3,1,2,4,5,5,6] and k = 4
Output: 4
Note: 
You may assume k is always valid, 1 ≤ k ≤ array's length.

tag: Medium

分析

这题最简单的做法是将数组排序,然后直接返回第K大的元素。复杂度为:O(NlogN)。但是,很明显,出题者并不想让我们这么做。

如果对数组排序,算法的复杂度起码是 O(NlogN)。那么如果我们不排序,能不能求出第K大元素呢?答案是可以的,我们知道快速排序中有一个步骤是 partition。它选择一个元素作为枢纽(pivot),将所有小于枢纽的元素放到枢纽的左边,将所有大于枢纽的元素放到枢纽的右边。然后返回枢纽在数组中的位置。那么,关键就在这里了。如果此时返回的枢纽元素在数组中的位置刚好是我们所要求的位置,问题就能得到解决了。

我们先选取一个枢纽元素,从数组的第一个元素开始到最后一个元素结束。将大于枢纽的左边元素和小于枢纽的右边元素交换。然后判断当前枢纽元素的 index 是否为 n - k,如果是则直接返回这个枢纽元素。如果 index 大于 n - k,则我们从枢纽的左边继续递归寻找,如果 index 小于 n - k,则我们从枢纽的右边继续递归寻找。

代码

public class KthLargestElementInArray {

    public int findKthLargest(int[] nums, int k) {
        return findKth(nums, nums.length - k /* 第k大,也就是排序后数组中 index 为 n - k 的元素*/, 0, nums.length - 1);
    }

    public int findKth(int[] nums, int k, int low, int high) {
        if (low >= high) {
            return nums[low];
        }
        int pivot = partition(nums, low, high); // 枢纽元素的 index
        if (pivot == k) {
            return nums[pivot];
        } else if (pivot < k) { // 枢纽元素的 index 小于 k,继续从枢纽的右边部分找
            return findKth(nums, k, pivot + 1, high);
        } else { // 枢纽元素的 index 大于 k,继续从枢纽的左边部分找
            return findKth(nums, k, low, pivot - 1);
        }
    }

    int partition(int[] nums, int low, int high) {
        int i = low;
        int j = high + 1;
        int pivot = nums[low];
        while (true) {
            while (less(nums[++i], pivot)) {
                if (i == high) {
                    break;
                }
            }
            while (less(pivot, nums[--j])) {
                if (j == low) {
                    break;
                }
            }
            if (i >= j) {
                break;
            }
            swap(nums, i, j);
        }
        swap(nums, low, j);
        return j;
    }

    boolean less(int i, int j) {
        return i < j;
    }

    void swap(int[] nums, int i, int j) {
        if (i == j) {
            return;
        }
        int temp = nums[i];
        nums[i] = nums[j];
        nums[j] = temp;
    }
}

复杂度分析

复杂度为:O(N) + O(N / 2) + O(N / 4) + ... 最终算法复杂度为 O(N)