秒懂算法系列 —— 最小 K 个数 & 分治法

992 阅读6分钟

您的点赞和关注是我坚持写作的最大动力,本人正在寻找测试开发的工作机会,欢迎微信联系: gyx764884989

什么是算法范式?

算法范式即算法设计的设计模式,类似软件工程中 GoF 提出的 23 种经典的设计模式,是前人在经历大规模实践后总结的较为通用的解决问题,优化问题的模板,学习和了解算法范式,就可以站在巨人的肩膀上解决问题。常见的算法范式有如下几类:

什么是分治法?

从字面意思理解,分治法即分而治之,就是把一个复杂的问题分成两个或更多的相同或相似的子问题,直到最后子问题可以简单的直接求解,原问题的解即子问题的解的合并。很多算法问题的经典解法都有分治法的身影,比如 Karatsuba 快速乘法算法、快速排序算法和并行算法

题目介绍

输入整数数组 arr ,找出其中最小的 k 个数。例如,输入 4、5、1、6、2、7、3、8 这 8 个数字,则最小的 4 个数字是 1、2、3、4。

示例 1:

输入:arr = [3,2,1], k = 2
输出:[1,2] 或者 [2,1]

示例 2:

输入:arr = [0,1,2,1], k = 1
输出:[0]

思路

要想输出前 K 小的数字列表,比较直观的想法是将整个数组排序,然后输出前 K 小的数字,故我们利用目前最高效的排序算法,快排的思想解决该问题,完成一次完整的快排,所需的时间复杂度为 O(nlogn)。这里就存在一个优化点,题目要求取前 K 个数字,所以快排流程不需要完全执行完,只要判断枢纽元素的位置与 K 的关系即可。按照快排的思想

假设在一次快排的划分后,枢纽元素位于下标 m 处,那么我们可以得到两个关键信息:

  • 这个枢纽元素左边有 m 个数
  • 这是原数组最小的 m 个数,换句话说就是这个 m 个数中最大的那个数,也比数组中其余的数字小

此时我们要做的事情就十分简单了,就要判断这个 m 个数到底是几个数字?是否为题目要求的 K 个数字?这就分为如下的三种情况:

  1. m = k, 巧了不是,我们恰好得到了题目所需的前 K 小的数字,搞定!
  2. m < k, 找到的数字都对,但是还不够,所以我们还要在剩余的数字中继续执行划分操作,找到剩余的数字
  3. m > k, 找多了,有些数字不对,但是正确结果一定都在这个 m 个数字中,需要对这 m 个数字继续执行划分操作

这样我们就把寻找前 K 小个数的问题等价转化为不断的进行划分,直到 m = k 的问题,这种方法也叫快速选择法。最终算法最理想的时间复杂度可以降低到 O(n)

再谈划分(Partition)

上述思路多次提到了划分的操作,如果有同学不熟悉的话我们这里再来回顾一下。划分操作是快排的核心思想。在一次划分完成后我们就可以找到一个分水岭数字,在这个分水岭左边的每一个数字都比分水岭数字小,右边的每一个数字都比分水岭数字大。这个分水岭数字被称为枢纽元素,具体流程可参见下面的图片:

  1. 以数组中最后一个数字作为临时枢纽元素
  2. 设置两个指针,分别置于数组的首尾两个位置,当然这里要排除掉刚刚临时指定的枢纽元素,第一个指针可称为首指针第二个可称为尾指针
  3. 向前移动首指针,如果指向的元素比枢纽元素小,则继续向前移动首指针,如果比枢纽元素大,则暂停移动
  4. 与首指针的移动规则类似,只是方向相反,同样我们这里找到第一个比枢纽元素 的原始,尾指针同样暂停移动
  5. 互换首位指针所指向的元素
  6. 不断重复 3,4,5 步骤,直到首位指针重合
  7. 将重合位置的元素与枢纽元素互换位置
  8. 一次划分即完成

代码

public int[] getLeastNumbers(int[] arr, int k) {
    if (k == 0) {
        return new int[0];
    } else if (arr.length <= k) {
        return arr;
    }

    // 原地不断划分数组
    partitionArray(arr, 0, arr.length - 1, k);

    // 数组的前 k 个数此时就是最小的 k 个数,将其存入结果
    int[] res = new int[k];
    for (int i = 0; i < k; i++) {
        res[i] = arr[i];
    }
    return res;
}

// lo 和 hi 分别是首位指针所指向的元素位置
void partitionArray(int[] arr, int lo, int hi, int k) {
    // 做一次 partition 操作
    int m = partition(arr, lo, hi);
    // 此时数组前 m 个数,就是最小的 m 个数
    if (k == m) {
        // 正好找到最小的 k(m) 个数
        return;
    } else if (k < m) {
        // 最小的 k 个数一定在前 m 个数中,递归划分
        partitionArray(arr, lo, m-1, k);
    } else {
        // 在右侧数组中寻找最小的 k-m 个数, 👇 这里后面会介绍一个细节 👇
        partitionArray(arr, m+1, hi, k);
    }
}

// partition 函数和快速排序中相同,具体可参考快速排序相关的资料
// 代码参考 Sedgewick 的《算法4》
int partition(int[] a, int lo, int hi) {
    int i = lo;
    int j = hi + 1;
    int v = a[lo];
    while (true) {
        while (a[++i] < v) {
            if (i == hi) {
                break;
            }
        }
        while (a[--j] > v) {
            if (j == lo) {
                break;
            }
        }

        if (i >= j) {
            break;
        }
        swap(a, i, j);
    }
    swap(a, lo, j);

    // a[lo .. j-1] <= a[j] <= a[j+1 .. hi]
    return j;
}

/**
 * 交换数组中两个元素的位置
 **/
void swap(int[] a, int i, int j) {
    int temp = a[i];
    a[i] = a[j];
    a[j] = temp;
}

代码细节

  partitionArray(arr, m+1, hi, k);

这里有一个细节,当 m < k 时,我们需要进行第二种类型的划分,即在剩余数组中寻找剩余的 k-m 个最小数,那么你是否有疑问,既然已经找到了 m 个最小数,还差 k-m 个,那么是不是最后一个参数也要传 k-m?其实这里的 k 的含义并不是剩余需要寻找的数字的个数,而是一个相对于整个数组而言,所需的前 K 小的数字的个数,无论你要在数组的哪个范围内进行划分,这个标准不能变,还有一种解释方法是,因为首指针没有从 0 开始计算,而是 m+1 ,如果首指针从 0 开始计算的话,那么最后一个参数可以为 k-m,但是这样不方便进行递归操作。

代码流程图

最理想的 O(n) 复杂度是如何计算得到的?

因为每次划分需要遍历的元素都相当于上次的 1/2, 因此可得到如下公式

O(n) = n + \frac{n}{2}+\frac{n}{4}+\frac{n}{8}+...+\frac{n}{2^k} = 2n

所以时间复杂度是 O(n)

参考