算法:求一个源源不断到来数据中的前K个最大(小)元素?以及第K个最大(小)元素?

1,087 阅读4分钟

问题描述:

假如现在给你一个无序的一连串整数,元素个数不确定,数据量很大,甚至源源不断地到来,但是你需要知道到目前为止的前K个最大(小)元素,以及第K个最大(小)元素。

算法实现:

因为“求前K个最大元素”和“求前K个最小元素”这两个问题的实现思路是一致的,因此下面以“求前K个最大元素”举例分析。

方案一:如果数据量不是很大,可以将所有数据放在数组中排序,然后返回前K个最大元素即可。

方案二:很明显,方案一在数据量特别大时效率会非常低,而且如果数据量不确定且源源不断地到来呢?因此我们可以换一种思路,维护一个固定长度为 K 的数组,然后每次新元素到来时将之插入到合适位置并踢出此时数组中的最小元素。这样数组中维护的永远都是前K个最大元素,而且不管数据量有多大,只要每次把数组中最小元素剔除出去即可。

方案三:进一步思考,我们可以发现方案二也有一个明显缺点,那就是每次都需要查找当前数组中的最小值,而且还要把新值插入到数组中的合适位置(如果新值比最小值大),而这个开销也是比较大的。那么我们有没有办法降低这个开销呢?答案肯定是可以的,那就是使用最小堆这种数据结构(PS:而且堆这种数据结构在存储形式上表现为数组结构,存储成本很低)。由最小堆的定义可知它的第一个元素(根节点)永远是堆中最小值,这样我们就可以很方便地将新值和这个最小值比较,如果新值小于堆中根节点,那么不需要调整,相反则需要进行向下调整,调整的效率为O(log2 K),这样一来总体效率就变成了O(N * log2 K)。可见,这种方案效率很高,而且存储成本也很低,因此我将在下面给出使用该方案实现的示例代码,以供大家参考。

import cn.zifangsky.queue.PriorityQueue;
import org.junit.Test;

import java.util.Arrays;
import java.util.Collection;
import java.util.Random;
import java.util.stream.IntStream;

/**
 * 求很大一个的乱序序列的前K个最小(大)的元素?求第K个最小(大)的元素?
 *
 * @author zifangsky
 * @date 2020/5/14
 * @since 1.0.0
 */
public class Problem_001_TopK {
    /**
     * 测试代码
     */
    @Test
    public void testMethods(){
        //1. 生成 1~100 的整数
        Integer[] arr = IntStream.rangeClosed(1, 100).boxed().toArray(Integer[]::new);

        //2. 将数组的顺序打乱
        shuffle(arr);

        //3. 求前10个最小的元素,以及第10小的元素
        Integer[] arr1 = Arrays.copyOf(arr, arr.length);
        Solution<Integer> solution1 = new Solution<>(PriorityQueue.Mode.MAX, 10);
        solution1.addAll(Arrays.asList(arr1));

        System.out.println(Arrays.toString(solution1.toArray(new Integer[10])));
        System.out.println("数组中第10小的元素:" + solution1.getKth());


        //4. 求前10个最大的元素,以及第10大的元素
        Integer[] arr2 = Arrays.copyOf(arr, arr.length);
        Solution<Integer> solution2 = new Solution<>(PriorityQueue.Mode.MIN, 10);
        solution2.addAll(Arrays.asList(arr2));

        System.out.println(Arrays.toString(solution2.toArray(new Integer[10])));
        System.out.println("数组中第10大的元素:" + solution2.getKth());
    }


    static class Solution<T extends Comparable<? super T>> {
        /**
         * 优先队列
         */
        private PriorityQueue<T> queue;

        /**
         * 前K个最小(大)的元素
         */
        private int k;

        public Solution(PriorityQueue.Mode mode, int k) {
            this.queue = new PriorityQueue<T>(mode, k);
            this.k = k;
        }

        /**
         * 添加一个集合进来
         */
        public void addAll(Collection<? extends T> c){
            for(T e : c){
                this.add(e);
            }
        }

        /**
         * 往队列中添加新值
         */
        public void add(T data){
            if(data == null){
                return;
            }

            //1. 如果当前优先队列中个数不足 K 个,则直接插入到队列
            if(queue.size() < k){
                queue.push(data);
                return;
            }

            //“最大优先队列”或者“最小优先队列”
            PriorityQueue.Mode mode = queue.getMode();
            //队列头节点
            T head = queue.top();

            /**
             * 以下两种情况不更改优先队列中的值
             *     最小优先队列(前K个最大的元素):新值比队列中最小值(头结点)还小
             *     最大优先队列(前K个最小的元素):新值比队列中最大值(头结点)还大
             */
            if(head != null && PriorityQueue.Mode.MIN.equals(mode) && data.compareTo(head) <= 0){
                return;
            }
            if(head != null && PriorityQueue.Mode.MAX.equals(mode) && data.compareTo(head) >= 0){
                return;
            }

            //2. 当触发上述两种情况之外的情况,则使用新值替换原来的头结点,成为TopN之一
            queue.pop();
            queue.push(data);
        }

        /**
         * 返回第K个最小(大)的元素
         */
        public T getKth(){
            return queue.isEmpty() ? null : queue.top();
        }

        public T[] toArray(T[] a){
            return queue.toArray(a);
        }
    }

    /**
     * 洗牌
     */
    private void shuffle(Integer[] arr){
        if(arr == null || arr.length < 1){
            return;
        }
        Random rnd = new Random();

        for(int i = (arr.length -1); i > 0; i--){
            swap(arr, i, rnd.nextInt(i));
        }
    }

    private void swap(Integer[] arr, int i, int j){
        int tmp = arr[i];
        arr[i] = arr[j];
        arr[j] = tmp;
    }

}

示例代码输出如下:

[10, 9, 8, 7, 5, 6, 4, 2, 1, 3]
数组中第10小的元素:10
[91, 94, 92, 97, 95, 93, 96, 99, 100, 98]
数组中第10大的元素:91

需要注意的是,上面的示例代码使用了我自己实现的优先队列(PS:优先队列的本质就是最大堆/最小堆),如需参考该类的实现可以看我的这个项目:gitee.com/zifangsky/D…。当然,大家也可以直接使用JDK中的优先队列(java.util.PriorityQueue)。