数据结构之手写 TreeMap

562 阅读14分钟

本篇文章将逐步带你手写一个 TreeMap ,使用Java 语言实现。

1. 迭代改递归练习

首先我们进行迭代改递归练习,帮助我们理解递归。

打印数组中所有元素

使用迭代实现

public static void printArr(int[] arr){
    
    for(int i = 0; i < arr.length; i++){
        System.out.println(arr[i]);
    }
}

使用递归实现:(注意递归要有终止条件)

public static void printArr(int[] arr){
    //传入初始索引
    printArr(arr, 0);
}
​
//定义:打印arr[i] 以及之后的所有元素
private static void printArr(int[] arr, int i) {
    //base case 递归结束条件
    if(i == arr.length){
        return;
    }
    //打印arr[i]
    System.out.println(arr[i]);
    //i++
    //然后打印arr[i + 1]
    printArr(arr, i + 1);
}

倒序打印所有元素?

  1. 根据 printArr(int[] arr, int i) 定义,我们只需要调换相关语句顺序即可:
//i++
//然后打印arr[i + 1]
printArr(arr, i + 1);   
//打印arr[i]
System.out.println(arr[i]);
  1. 改变base case 和初始传入值:
//传入初始索引
printArr(arr, arr.length - 1);
_________________
//base case 递归结束条件
if(i < 0){
    return;
}
//打印arr[i]
System.out.println(arr[i]);
//i--
//然后打印arr[i - 1]
printArr(arr, i - 1);

搜索数组中指定值返回其索引,如果找不到返回-1

使用迭代实现:

public static int search(int[] arr, int target) {
    for(int i = 1; i < arr.length; i++){
        if(arr[i] == target){
            return i;
        }
    }
    return -1;
}

使用递归实现:

public static int search(int[] arr, int target) {
    return search(arr, target, 0);
}
​
//定义:在arr[i]以及arr[i]之后寻找target
private static int search(int[] arr, int target, int i) {
    //base case
    if(i == arr.length){
        return -1;
    }
​
    //base case
    if(arr[i] == target){
        return i;
    }
    
    //当前找不到,那么在i+1之后寻找
    return search(arr, target, i + 1);
}

2. 手写单链表 RecursiveList

之后我们需要手写一个单向链表使用纯递归实现,来为后续实现 TreeMap 打下基础。 底层是一个链表


首先看初始代码

/**
 * 单链表递归实现
 */
public class RecursiveList<E> {
​
    // 单链表链表节点
    private static class Node<E> {
        E val;
        //指向下一个节点
        Node<E> next;
​
        Node(E val) {
            this.val = val;
        }
    }
​
    //头节点
    
    private Node<E> first = null;
​
    //长度
    
    private int size = 0;
​
    public RecursiveList() {
    }
​
    /***** 增 *****/
​
    //在头部插入节点
    public void addFirst(E e) {
    }
​
    public void addLast(E e) {
        
    }
​
    public void add(int index, E e) {
        
    }
​
    /***** 删 *****/
​
    public E removeFirst() {
       
    }
​
    public void removeLast() {
        
    }
​
    public void remove(int index) {
       
    }
​
    /***** 查 *****/
    
    public E get(int index){
        
    }
​
    public E getFirst() {
        
    }
​
    public E getLast() {
      
    }
​
    /***** 改 *****/
​
    public E set(int index, E element) {
        
    }
​
    /***** 其他工具函数 *****/
    public int size() {
        return size;
    }
​
    public boolean isEmpty() {
        return size == 0;
    }
​
    private boolean isElementIndex(int index) {
        return index >= 0 && index < size;
    }
​
    private boolean isPositionIndex(int index) {
        return index >= 0 && index <= size;
    }
​
    /**
     * 检查 index 索引位置是否可以存在元素
     */
    private void checkElementIndex(int index) {
        if (!isElementIndex(index))
            throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size);
    }
​
    /**
     * 检查 index 索引位置是否可以添加元素
     */
    private void checkPositionIndex(int index) {
        if (!isPositionIndex(index))
            throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size);
    }
    
    //返回 index 对应的 Node
    //注意:请保证传入的 index 是合法的
    private Node<E> getNode(int index){
        Node<E> p = first;
        for(int i = 0; i < index; i++){
            p = p.next;
        }
        return p;
    }
​
}
​

我们首先实现xxxFirst 相关的方法,因为这些方法不涉及 for 循环。

  1. addFirst 在链表头部添加元素

public void addFirst(E e) {
    //首先做出这个节点
    Node<E> x = new Node<>(e);
    x.next = first;
    first = x;
    size++;
}
  1. removeFirst 删除头部节点
public E removeFirst() {
    //首先判断链表是否为空
    if(isEmpty()){
        throw new NoSuchElementException();
    }
    E deleteVal = first.val;
    // 删除
    first = first.next;
    return deleteVal;
}
  1. getFirst 获取头部节点
public E getFirst() {
    if(isEmpty()){
        throw new NoSuchElementException();
    }
    return first.val;
}

然后我们实现 getNode 方法,根据索引获取对应的 Node 节点

起初我们是通过 for 循环实现的

//返回 index 对应的 Node
//注意:请保证传入的 index 是合法的
private Node<E> getNode(int index){
    Node<E> p = first;
    for(int i = 0; i < index; i++){
        p = p.next;
    }
    return p;
}

那么如何通过递归实现:我们可以仿照我们最初的 search 函数迭代改递归:

// 返回「从 node 开始的第 index 个链表节点」
private Node<E> getNode(Node<E> node, int index) {
    // base case
    if (index == 0) {
        return node;
    }
    // 返回 「从 node.next 开始的第 index - 1 个链表节点」
    return getNode(node.next, index - 1);
}

那么 getNode(int index) 如何写?

private Node<E> getNode(int index){
    //返回从first 开始的第index个节点
    return getNode(first, index);
}

基于 getNode 我们就可以实现 getset 方法了

public E get(int index){
    //检查索引是否越界
    checkElementIndex(index);
    Node<E> node = getNode(index);
    return node.val;
}
​
public E set(int index, E element) {
    checkElementIndex(index);
    Node<E> p = getNode(index);
​
    E oldVal = p.val;
    p.val = element;
​
    return oldVal;
}

然后我们实现 getLast 方法

我们可以通过 getNode(size -1)来获取最后一个节点

public E getLast() {
    if (isEmpty()) {
        throw new NoSuchElementException();
    }
    return getNode(size -1).val;
}

那么我们能不能通过递归来实现呢?

public E getLast() {
    if (isEmpty()) {
        throw new NoSuchElementException();
    }
    return getLast(first);
}
​
// 返回 node 之后最后一个节点
private E getLast(Node<E> node){
    // base case
    if(node.next == null){
        //如果该节点的下一个节点为空,则其为尾节点
        return node.val;
    }
    //不满足base case 继续寻找下一个节点
    return getLast(node.next);
}

修改链表结构 removeLast remove addLast add

  1. 实现 removeLast

凡是修改链表结构的,都需要有一个返回值:

public void removeLast() {
    if (isEmpty()) {
        throw new NoSuchElementException();
    }
​
    first = removeLast(first);
    size--;
}
​
// x -> y -> null
private Node<E> removeLast(Node<E> node) {
    //base case
    if (node.next == null) {
        // node 就是最后一个节点,让自己直接消失
        return null;
    }
​
    node.next = removeLast(node.next);
    return node;
}

为什么这样写?

我们从后往前看,也就是找到最后一个节点时,我们返回 null ,然后被其上一个节点的 next 指针接收,一直返回,最后就只删除了最后一个节点。

2. 实现 remove(int index)

首先我们在 remove(int index),书写函数 remove(first, index),该函数表明删除 first 节点之后的第 index 个节点。

public void remove(int index) {
    checkElementIndex(index);
    first = remove(first, index);
    size--;
}

base case 结束条件,也就是 index == 0时找到了这个节点。那么我们应该返回什么呢?

  • 我们应该返回 node.next,也就是返回当前节点之后的节点,与当前节点之前的节点连接起来。
  • 然后 node.next = remove(node.next, index - 1) 接收返回的值。

image.png

//删除node之后的第 index 个节点
private Node<E> remove(Node<E> node, int index) {
    //base case
    if (index == 0) {
        return node.next;
    }
    node.next = remove(node.next, index - 1);
    return node;
}
  1. 实现 addLast
public void addLast(E e) {
    first = addLast(first, e);
    size++;
}

我们需要在最后插入一个节点,也就是当 node == null 时,返回新添加的节点,然后进行接收即可。

// a -> b -> c -> d -> e -> null
private Node<E> addLast(Node<E> node, E e) {
    //说明指向最后一个节点的next
    if (node == null) {
        return new Node<>(e);
    }
    //返回被接受,连接
    node.next = addLast(node.next, e);
    return node;
}
  1. 实现 add,在指定索引处插入节点
public void add(int index, E e) {
    checkPositionIndex(index);
    if (index == size) {
        addLast(e);
        return;
    }
    first = add(first, index, e);
    size++;
}

index == 0 时,说明找到了这个插入点,然后插入节点的next 指针需要指向当前插入节点(也就是连接插入点之后的节点),然后进行返回。与插入点前的节点连接。

// a -> b -> c -> null
private Node<E> add(Node<E> node, int index, E e) {
    if (index == 0) {
        //假设在b处插入节点
        Node<E> x = new Node<>(e);
        //需要把b节点之后的链表连接到x.next上
        x.next = node;
        //然后返回这段链表
        return x;
    }
    node.next = add(node.next, index - 1, e);
    return node;
}

单链表完整代码

3. 实现TreeMap

3.1. TreeMap原理及特性

TreeMap 底层是基于二叉查找树BST,有一个特点,根节点的值比左子树要大,比右子树要小。

为什么要使用BST?

  • 因为BST左小右大的特性,我们可以进行二分搜索。比如我们搜索2,先从根节点开始查找,2比6小,那么搜索左子树,否则搜索左子树,直到找到该节点。
  • 通常来说BST的查询效率很高,如果所有的节点都接到左子树上,那么就退化成单链表(不是自平衡的),那么搜索的效率就是O(n) 了。
  • Java的TreeMap底层使用的是红黑树,也就是自平衡的二叉搜索树。

我们基于什么来实现?

  • 我们基于 BST 实现,不考虑退化的情况,也就是不使用红黑树。

3.2. 初次实现

首先看初始代码

为什么 K extends Comparable<K> ?

  • 我们是根据 Key 的大小,将 key 存入到左右子树当中的。所以 key 必须是可比较的。
public class MyTreeMap<K extends Comparable<K>, V> {

    private class TreeNode {
        K key;
        V val;
        TreeNode left, right;
        
        TreeNode(K key, V val) {
            this.key = key;
            this.val = val;
            this.size = 1;
            left = right = null;
        }
    }

    private TreeNode root = null;

    public MyTreeMap() {
    }

    /***** 增/改 *****/

    // 添加 key -> val 键值对,如果键 key 已存在,则将值修改为 val
    public V put(K key, V val) {

    }


    /***** 删 *****/

    // 删除 key 并返回对应的 val
    public V remove(K key) {

    }


    // 删除并返回 BST 中最小的那个 key
    public void removeMin() {

    }


    // 删除并返回 BST 中最大的那个 key
    public void removeMax() {

    }

    /***** 查 *****/

    // 返回 key 对应的 val,如果 key 不存在,则返回 null
    public V get(K key) {

    }


    // 返回小于等于 key 的最大的键
    public K floorKey(K key) {

    }


    // 返回大于等于 key 的最小的键
    public K ceilingKey(K key) {

    }

    // 返回小于 key 的键的个数
    public int rank(K key) {
       
    }

    // 返回索引为 i 的键,i 从 0 开始计算
    public K select(int i) {
      
    }


    // 返回 BST 中最大的键
    public K maxKey() {
      
    }

    // 返回 BST 中最小的键
    public K minKey() {
      
    }

    // 判断 key 是否存在 Map 中
    public boolean containsKey(K key) {
       
    }

    /***** 工具函数 *****/

    public boolean isEmpty() {
        return size == 0;
    }

}

我们首先实现 get(K key) 方法

public V get(K key){
    //进行判断
    if(key == null){
        throw new IllegalArgumentException("key is null");
    }
    //定义get(root, key)方法
    TreeNode node = get(root, key);
    if(node == null){
        return null;
    }
    return node.val;
}

我们需要定义 get(TreeNode node, K key),从根节点开始遍历,找到指定 key 对应的节点。

  • 首先我们需要比较当前节点的 Key 与指定 Key 的大小。node.key.compareTo(key)
  • 如果 node == null ,说明没找到。
  • 如果当前根节点 key 小于指定 Key,那么就搜索右子树;如果大于就搜索左子树。
  • 剩下的情况就是找到了。
//获取node根节点下指定key对应的节点
private TreeNode get(TreeNode node, K key) {
    if(node == null){
        return null;
    }
    int i = node.key.compareTo(key);
    // i < 0 说明当前根节点小于 key
    if(i < 0){
        //说明在右子树
        return get(node.right, key);
    }
    if(i > 0){
        // i > 0 说明在左子树
        return get(node.left, key);
    }
    // 剩下的情况就是找到了
    return node;
}

然后我们实现 containsKey(K key) 方法,判断 key 是否存在。 我们可以直接使用 get(K key) 方法,判断其返回值是否为空即可。

public boolean containsKey(K key){
    if(key == null){
        throw new IllegalArgumentException("key is null");
    }
    //直接使用上面的方法
    TreeNode x = get(root, key);
    return x != null;
}

实现 put(K key, V val) 方法

有两种情况:key 存在修改或者 key 不存在插入。

public V put(K key, V val){
    if(key == null){
        throw new IllegalArgumentException("key is null");
    }
    V oldVal = get(key);
    //首先获取Key,如果oldVal不存在那么就相当于新增 size++
    if(oldVal == null){
        size++;
    }
    //为什么赋值给root 因为新增改变了二叉树
    root = put(root, key, val);
    return oldVal;
}

首先我们需要将当前节点 key 与插入 key 进行比较,如果 cmp > 0 说明在左子树,cmp < 0 说明在右子树,如果 cmp == 0 说明找到了,修改其值即可。 如果找不到也就是 node == null 时,直接返回新节点即可。因为 node.left 或者 node.right其一会进行接收。

private TreeNode put(TreeNode node, K key, V val) {
    //2.找不到,判断是否为空,如果为空,说明没有找到
    if(node == null){
        //为什么直接返回,当找不到的时候
        //只能在左子树或者右子树搜索
        //找不到返回值会被左右子树接收。
        return new TreeNode(key, val);
    }

    // 1.先找是否存在
    int cmp = node.key.compareTo(key);
    if(cmp > 0){
        node.left = put(node.left, key, val);
    }else if(cmp < 0){
        node.right = put(node.right, key, val);
    }else{
        // node.key == key 找到了
        node.val = val;
    }
    return node;
}

删除最小 Key 和最大 Key removeMin() removeMax()

image.png 我们观察这个树,可以发现最左子树就是最小值,最右子树就是最大值。 那么我们只需要判断 node.left == nullnode.right == null 就找到了最左子树和最右子树。

  1. 实现 removeMin() 删除最左子树。

首先我们需要从根节点查找到最左子树。

//删除最小值,也就是删除最左子树
public void removeMin() {
    if (isEmpty()) {
        throw new NoSuchElementException();
    }
    root = removeMin(root);
    size--;
}


node.left == null 说明已经走到最左侧节点,当前节点没有左子树,当前节点就是最左子树了。

可以返回 null 吗?

  • 不可以,如果当前节点有右子树,则会将当前节点和右子树一起删除。

image.png

  • 所以应该返回其右子树,让其覆盖当前节点即可,我们不用操心右子树是否为空。同时也维护了 BST 的性质。

image.png

private TreeNode removeMin(TreeNode node) {
    if(node.left == null){
        return node.right;
    }
    node.left = removeMin(node.left);
    return node;
}
  1. 实现 removeMax() 删除最右子树。

参照 removeMax实现可得

public void removeMax(){
    if (isEmpty()) {
        throw new NoSuchElementException();
    }
    root = removeMax(root);
    size--;
}

private TreeNode removeMax(TreeNode node) {
    if(node.right == null){
        return node.left;
    }
    node.right = removeMax(node.right);
    return node;
}

初步实现完整代码

3.3. TreeMap的删除

实现 remove(K key) 方法,删除指定 key 对应的节点。

// 删除 key 并返回对应的 val
public V remove(K key){
    // 检查参数有效性
    if(key == null){
        throw new IllegalArgumentException("key is null");
    }
    if(!containsKey(key)){
        return null;
    }
    
    // 获取旧值
    V deleteVal = get(key);
    
    // 删除节点
    root = remove(root, key);
    size--;
    return deleteVal;
}

然后我们定义 remove(TreeNode node, K key) 方法,表明删除当前 node 节点下对应二叉树中节点 key 等于指定 key

private TreeNode remove(TreeNode node, K key) {

}

我们应该怎么做?

  • 首先将当前节点 key 和需要删除的 key 进行比较。
  • 如果 node.key > key 则查询左子树。
  • 如果 node.key < key 则查询右子树。

那么相等的情况,我们应该怎么删除呢?会涉及三种情况?

  1. node 是叶子节点,左右子树都是 null
  2. node 左右子树有一个非空
  3. node 左右子树都不为空

node 是叶子节点,左右子树都是 null ,这个我们只需返回 null 即可。可以写如下的代码:

左右子树有一个非空怎么办呢?

image.png 假如我们要删除的节点是11,其左子树不为空,那么我们只需要返回其左子树即可,根据图示可以发现,BST的性质没有改变。 同理,如果删除节点的右子树不为空,那么只需要返回其右子树即可。我们可以写如下代码:

node 左右子树都不为空怎么办呢?

image.png 假如我们要删除的节点是根节点 6。那么我们应该怎么才能维持BST的性质呢?

  • 我们可以移动左子树的最大节点5至根节点,或者移动右子树的最小节点7至根节点来维持BST的性质。
  1. 我们可以移动左子树的最大节点5至根节点:

image.png 确实可以维护BST的性质。

疑问?

  • 难道删除节点的左子树的最大节点不能有左右子树吗,那样不是变成多于二叉的树了吗?

    • 首先最大节点肯定没有右子树,如果有那么其不是最大节点。
    • 那么它会有左子树吗,可能有左子树,那么我们应该怎么办呢?

image.png

我们可以把最大节点的左子树接到其父节点上,然后将从当前位置删除。记得保留当前节点的值,然后将该节点的左右指针指向删除节点指向的左右子树。这样就可以维护BST的性质了。

  1. 移动右子树的最小节点7至根节点来维持BST的性质。

image.png 疑问?

  • 我们右子树的最小节点,一定没有左子树,但是其可能有右子树,那么我们应该怎么办呢?

image.png 我们只需要把右子树接到最小节点的父节点上,然后将最小节点从当前位置删除,然后保留最小节点的值,然后将其左右指针指向删除节点的左右子树。这样就维护了BST的性质。

完整代码实现

  1. 首先我们需要实现两个方法,找到以某个节点为根节点的BST的最大节点和最小节点:
// 以当前节点为根节点的BST的最大节点
private TreeNode maxNode(TreeNode p) {
    while (p.right != null) {
        p = p.right;
    }
    return p;
}

// 以当前节点为根节点的BST的最小节点
private TreeNode minNode(TreeNode p) {
    while (p.left != null) {
        p = p.left;
    }
    return p;
}
  1. 然后实现我们的方法:前面的章节我们已经实现了 removeMinremoveMax 方法了。
  private TreeNode remove(TreeNode node, K key) {
        // 进行比较
        int cmp = node.key.compareTo(key);
        if(cmp > 0){
            // node.key > key 去左子树找
            node.left = remove(node.left, key);
        }else if(cmp < 0){
            // node.key < key 去右子树找
            node.right = remove(node.right, key);
        }else{
            if(node.left == null && node.right == null){
                // 左右子树都为空
                return null;
            }else if(node.left != null && node.right == null){
                //左子树不为空
                return node.left;
            }else if(node.left == null && node.right != null){
                // 右子树不为空
                return node.right;
            }
            // 剩下的情况的就是左右子树不为空的情况了。
            // 我们有两种方案实现
            // 1. 找到当前节点的前驱节点,也就是左子树的最大值
            //首先我们需要找到这个最大节点
            TreeNode leftMax = maxNode(node.left);
            //然后通过我们的 removeMax删除这个节点,让其父节点的左指针指向这个删除这个节点后的左子树
            node.left = removeMax(node.left);
            leftMax.left = node.left;
            leftMax.right = node.right;
//            // 2. 找到当前节点的后继节点,也就是右子树的最小值
//            TreeNode rightMin = minNode(node.right);
//            node.right = removeMin(node.right);
//            rightMin.left = node.left;
//            rightMin.right = node.right;
        }
        return node;
    }

3.4. 实现floorKey和ceilingKey方法

// 返回小于等于 key 的最大的键
public K floorKey(K key) {
    
}

// 返回大于等于 key 的最小的键
public K ceilingKey(K key) {

}

首先我们实现 floorKey 方法:

public K floorKey(K key) {
    // 参数检查
    if (key == null) {
        throw new IllegalArgumentException("key is null");
    }
    if (isEmpty()) {
        throw new NoSuchElementException();
    }
    
    //创建辅助函数
    TreeNode x = floorKey(root, key);
    return x.key;
}

注意:我们要查找的是小于等于key的最大的键,如果key存在于BST中,那么就返回key。我们只需要考虑key不存在的情况。 image.png 比如我们要查找小于等于16的最大的键,我们一直向下查找,发现找不到,那么我们返回其父节点就可以了。

实现:逻辑类似于 get 方法,只需要修改右子树处的相关代码即可。

private TreeNode floorKey(TreeNode node, K key) {
    if(node == null){
        return null;
    }
    int i = node.key.compareTo(key);
    // i < 0 说明当前根节点小于 key
    if(i < 0){
        //说明在右子树
        TreeNode x =  floorKey(node.right, key);
        // 当发现在右子树查找时,返回 null 说明找不到,那么返回其父节点即可,也就是返回当前节点 node。
        if(x == null){
            return node;
        }
    }
    if(i > 0){
        // i > 0 说明在左子树
        return floorKey(node.left, key);
    }
    // 剩下的情况就是找到了
    return node;
}

同理实现:ceilingKey

// 返回大于等于 key 的最小的键
public K ceilingKey(K key) {
    if (key == null) {
        throw new IllegalArgumentException("key is null");
    }
    if (isEmpty()) {
        throw new NoSuchElementException();
    }

    TreeNode x = ceilingKey(root, key);
    return x.key;
}

private TreeNode ceilingKey(TreeNode node, K key) {
    if(node == null){
        return null;
    }
    int i = node.key.compareTo(key);
    // i < 0 说明当前根节点小于 key
    if(i < 0){
        //说明在右子树
        return ceilingKey(node.right, key);
    }
    if(i > 0){
        // i > 0 说明在左子树
        TreeNode x = ceilingKey(node.left, key);
        if(x == null){
            return node;
        }
        return x;
    }
    // 剩下的情况就是找到了
    return node;
}

3.5. 实现 keys 相关方法

// 从小到大返回所有键  
public Iterable<K> keys() {  

}  

// 从小到大返回闭区间 [min, max] 中的键  
public Iterable<K> keys(K min, K max) {  

}

首先我们实现从小到大返回所有键,根据BST的特性,那么通过中序遍历(左 根 右)BST,就可以返回从小到大的键了。

// 从小到大返回所有键
public Iterable<K> keys() {
    if (isEmpty()) {
        return new LinkedList<>();
    }
    LinkedList<K> list = new LinkedList<>();
    traverse(root, list);
    return list;
}

// 中序遍历 BST
private void traverse(TreeNode node, LinkedList<K> list) {
    if (node == null) {
        return;
    }
    // 先遍历left
    traverse(node.left, list);
    // 中序遍历
    list.addLast(node.key);
    // 再遍历 right
    traverse(node.right, list);
}

实现从小到大返回闭区间 [min, max] 中的键,我们只需要在中序遍历的时候添加值得时候,判断当前节点是否在 [min, max] 范围内。

// 从小到大返回闭区间 [min, max] 中的键
public Iterable<K> keys(K min, K max) {
    if (min == null) throw new IllegalArgumentException("min is null");
    if (max == null) throw new IllegalArgumentException("max is null");

    LinkedList<K> list = new LinkedList<>();
    traverse(root, list, min, max);
    return list;
}

// 中序遍历 BST
private void traverse(TreeNode node, LinkedList<K> list, K min, K max) {
    if (node == null) {
        return;
    }

    int cmpMin = min.compareTo(node.key);
    int cmpMax = max.compareTo(node.key);

    traverse(node.left, list);

    // 中序遍历 min <= node.key <= max
    if (cmpMin <= 0 && cmpMax >= 0) {
        list.addLast(node.key);
    }

    traverse(node.right, list);
    
}

但是这样的效率很低?

  • 因为我们没有BST的性质。

  • min 大于等于当前 node 的时候,那么我们就不需要遍历当前 node 的左子树了,因为左子树的节点都小于 node,自然也小于 min

  • max 小于等于当前 node 的时候,那么我们就不需要遍历当前 node 的右子树了,因为右子树的节点都大于 node,自然也大于 max

优化后的代码:

private void traverse(TreeNode node, LinkedList<K> list, K min, K max) {
    if (node == null) {
        return;
    }

    int cmpMin = min.compareTo(node.key);
    int cmpMax = max.compareTo(node.key);

    if (cmpMin < 0) {
        // min < node.key 才进行遍历
        traverse(node.left, list);
    }

    // 中序遍历 min <= node.key <= max
    if (cmpMin <= 0 && cmpMax >= 0) {
        list.addLast(node.key);
    }

    if (cmpMax > 0) {
        // max > node.key 才进行遍历
        traverse(node.right, list);
    }
}

3.6. 实现 select 和 rank 方法

// 返回小于 key 的键的个数
public int rank(K key) {
   
}
// 返回索引为 i 的键,i 从 0 开始计算
public K select(int i) {

}

实现 rank 方法,根据BST的性质,小于 key 的键的个数,就是统计其左子树节点的个数。

我们可以通过前序遍历、中序遍历、后序遍历来统计节点个数,但是这样的时间效率很高。

有没有一种方法可以快速统计节点的个数呢?

  • 我们可以在 TreeNode 类中添加 size 属性,用来记录以当前节点为根的 BST 有多少个节点。然后移除外部的 size 属性。

image.png

  • 然后需要修改相应的工具函数:

image.png

  • 然后修改改变BST结构的方法,put, removeXxx

image.png 同上图改造其他方法即可。

正式实现 rank 方法

// 返回小于 key 的键的个数
public int rank(K key) {
    if (key == null) {
        throw new IllegalArgumentException();
    }
    return rank(root, key);
}

创建辅助函数:

  1. 首先我们需要比较当前节点 key 与指定 key 的大小。
  2. 如果 key < node.key,根据 BST 的性质,说明 nodenode.right 都大于 key 。我们只需要查找左子树就好了。
  3. 如果 key > node.key,根据 BST 的性质,说明 nodenode.left 都是比 key 小的,我们需要返回左子树的个数和当前节点个数1,然后查找右子树返回其个数即可。
  4. 如果相等,说明 node 节点的左子树满足,直接返回个数即可。
// 返回以 node 为根的 BST 中小于 key 的键的个数
private int rank(TreeNode node, K key) {
    int cmp = key.compareTo(node.key);
    
    if (cmp < 0) {
        // key < node.key
        // 和 node 以及 node.right 没啥关系了
        // 因为它们太大了
        return rank(node.left, key);
    } else if (cmp > 0) {
        // key > node.key
        // node 和 node.left 左子树都是比 key 小的
        return size(node.left) + 1 + rank(node.right, key);
    } else {
        // key == node.key
        return size(node.left);
    }
}

实现 select 方法,返回索引为 i 的键,i0 开始计算

返回索引为 i 的键,是什么意思呢?

  • 就是我们通过中序遍历,从小到大排序的节点。索引 0 就是返回最小的元素。
// 返回索引为 i 的键,i 从 0 开始计算
public K select(int i) {
    if (i < 0 || i >= size()) {
        throw new IllegalArgumentException();
    }

    TreeNode x = select(root, i);
    return x.key;
}

定义辅助函数:

  1. 首先我们需要计算出当前节点的索引,那么如何计算呢,通过 size(node.left) 方法计算,因为根据BST的性质,左小右大,其左子树节点个数就是其索引。

image.png 2. 然后比较当前节点索引 ni 的大小:

  • 如果 n > i,那么我们只能在当前节点左子树查找,因为当前节点右子树索引大于 n 也大于 i
  • 如果 n < i,那么我们只能在当前节点右子树查找,因为当前节点左子树索引都小于 n
    • 为什么是 i - n -1
    • 如果我们需要查找索引为10的节点,当前节点索引为 5,那么我们肯定要查找右子树。如果索引还是传入i=10,那么下一轮重新计算 node 的索引,就是 0 似乎没什么问题;如果计算下一轮,那么索引变为了3,i还是10,总数是4个,在这个小的子树中,肯定找不到索引为10的节点。
    • i - n - 1 => 10 - 5 - 1 = 4 计算出其在右子树的索引,然后来查找。

image.png

  • 如果 n = i,说明找到了直接返回即可。
// 返回以 node 为根的 BST 中索引为 i 的那个节点
private TreeNode select(TreeNode node, int i) {
    int n = size(node.left);

    if (n > i) {
        // n == 10, i == 3
        return select(node.left, i);
    } else if (n < i) {
        // n == 3, i == 10
        return select(node.right, i - n - 1);
    } else {
        // i == n
        // node 就是索引为 i 的那个节点
        return node;
    }
}

TreeMap 完整代码