Implementing a Simplified HashMap with TDD


Preface

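In the previous article I walked through the HashMap source code. While it is still fresh, I took some time today to implement a simplified HashMap using TDD. In my tests, it runs slightly slower than HashMap under normal conditions.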

Preliminary Design

public class SimpleHashMap<K, V> {
    // Planned public API: signatures only, bodies are filled in test by test.
    public V put(K key, V value);
    public V get(K key);
    public V remove(K key);
    public boolean containsKey(K key);
    public int size();
    public Iterable<V> values();
    public void forEach(Consumer<? super K> action);
}

Tasking

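  • Construct SimpleHashMap with a no-arg constructor
  • Constructor that sets the initial capacity
  • Constructor that sets the initial capacity and load factor
  • The initial capacity defaults to 16
  • The load factor defaults to 0.75f
  • The initial resize threshold is initial capacity * load factor
  • On each later resize, the threshold doubles: threshold = threshold << 1
  • Add the put method
    • Compute the hash
    • Add a hash table to hold the data nodes.
    • If the hash table's capacity is 0, or the next insertion would reach the threshold, set the new resize threshold, grow the table, and rehash.
    • The hash table index is hash & (capacity - 1)
    • When growing, move the data from the old hash table to the new one
    • Before moving data into the new hash table, rehash every node: new index = entry.hash & (new_capacity - 1)
    • On hash collisions, store the nodes in a linked list
    • If more than 8 nodes collide in the same bucket, store them in a red-black tree
  • Add the size method
    • Add a size field.
    • When put inserts a new key, size += 1.
    • When remove succeeds, size -= 1.
    • Handle linked lists
    • Handle red-black trees
  • Add the containsKey method
    • Compute the hash from the key
    • Compute the index from the hash
    • Look up the key at the index: return true if found, otherwise return false.
    • Handle a null hash table.
    • Handle linked lists
    • Handle red-black trees
  • Add the get method
    • Compute the hash from the key
    • Compute the index from the hash
    • Look up the bucket at the index
    • If the bucket holds multiple nodes, compare the keys by reference and by equals.
    • On a match, return the corresponding value; otherwise return null.
    • Handle linked lists
    • Handle red-black trees
  • Add the remove method
    • Compute the hash from the key
    • Compute the index from the hash
    • Look up the bucket at the index
    • On a match, unlink the matching node from the bucket (the slot becomes null only if that node was the only one) and return its value; otherwise return null.
    • Handle linked lists
    • Handle red-black trees
  • Add the values method
    • On each successful put, add the value to the list
    • On each successful put that replaces a value, replace the corresponding value in the list
    • On each successful remove, delete the value from the list
    • Handle linked lists
    • Handle red-black trees
  • Add the forEach method
    • Traverse the hash table
    • For every node in a non-empty bucket, call action.accept(key)
    • Handle linked lists
    • Handle red-black trees
  • Add fail-fast behavior
    • Add a modCount field that counts modifications
    • Check that modCount is the same before and after iteration
    • If modCount has changed, throw ConcurrentModificationException.
  • Add a red-black tree to hold the nodes of buckets with more than 8 collisions.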
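The sizing and indexing rules in the Tasking list can be checked in isolation. The class below is a hypothetical standalone sketch, not part of SimpleHashMap. `spread` reproduces the mixing done by SimpleHashMap's `hash()`, and `index` is the `hash & (capacity - 1)` formula. `tableSizeFor` is an assumed helper for the "initial capacity" constructor: it rounds a requested capacity up to a power of two, as JDK HashMap also does, because the mask-based index only works for power-of-two capacities.

```java
public class HashMath {
    static final int DEFAULT_INITIAL_CAPACITY = 16;
    static final float DEFAULT_LOAD_FACTOR = 0.75f;
    static final int MAXIMUM_CAPACITY = 1 << 30;

    // Same mixing as SimpleHashMap.hash(): fold the high 16 bits into the low 16 bits.
    static int spread(Object key) {
        int h;
        return key == null ? 0 : (h = key.hashCode()) ^ (h >>> 16);
    }

    // Bucket index; only correct when capacity is a power of two.
    static int index(int hash, int capacity) {
        return hash & (capacity - 1);
    }

    // Hypothetical helper: round a requested capacity up to the next power of two.
    static int tableSizeFor(int cap) {
        if (cap <= 1) {
            return 1;
        }
        if (cap >= MAXIMUM_CAPACITY) {
            return MAXIMUM_CAPACITY;
        }
        return Integer.highestOneBit(cap - 1) << 1;
    }

    // Initial resize threshold = capacity * load factor.
    static int initialThreshold(int capacity, float loadFactor) {
        return (int) (capacity * loadFactor);
    }

    public static void main(String[] args) {
        int cap = DEFAULT_INITIAL_CAPACITY;
        int threshold = initialThreshold(cap, DEFAULT_LOAD_FACTOR);
        System.out.println("capacity=" + cap + ", threshold=" + threshold);

        // One resize doubles both the capacity and the threshold.
        int newCap = cap << 1;
        int newThreshold = threshold << 1;
        System.out.println("capacity=" + newCap + ", threshold=" + newThreshold);

        // Rehash: keys 1 and 17 share a bucket at capacity 16 but split at capacity 32.
        for (int key : new int[] {1, 17}) {
            int h = spread(key);
            System.out.println(key + ": " + index(h, cap) + " -> " + index(h, newCap));
        }
        System.out.println("tableSizeFor(10)=" + tableSizeFor(10));
    }
}
```

Running `main` should report capacity 16 with threshold 12, then 32 with 24 after one doubling. Keys 1 and 17 both land in bucket 1 at capacity 16, but at capacity 32 key 17 moves to bucket 17. This is why a resize has to recompute the index of every node in a chain rather than move the chain as a unit.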

Test Coverage

Test Code

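import org.assertj.core.util.Lists;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;

/**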
 * @author lyning
 */
public class SimpleHashMapTest {

    private SimpleHashMap<Integer, Integer> map;

    @BeforeEach
    public void setUp() throws Exception {
        // given
        this.map = new SimpleHashMap<>();
    }

    /************ size test start **********/
    @Test
    @DisplayName("given empty entries " +
            "when call size() " +
            "then return 0")
    public void size1() {
        // when
        int size = map.size();
        // then
        assertThat(size).isZero();
    }

    @Test
    @DisplayName("given multiple entries(contains duplicate key) " +
            "when call size() " +
            "then return correct size")
    public void size2() {
        // given
        SimpleHashMap<Integer, Integer> map = new SimpleHashMap<>();
        map.put(1, 1);
        map.put(2, 2);
        map.put(3, 3);
        map.put(3, 4);
        map.put(3, 5);
        map.put(4, 4);
        map.put(5, 5);
        map.remove(1);
        map.remove(2);
        // when
        int size = map.size();
        // then
        assertThat(size).isEqualTo(3);
    }

    @Test
    @DisplayName("given multiple entries(hash conflict) " +
            "when call size() " +
            "then return correct size")
    public void size3() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        map.remove(new HashConflict(5));
        map.remove(new HashConflict(3));
        // when
        int size = map.size();
        // then
        assertThat(size).isEqualTo(3);
    }
    /************ size test end **********/


    /************ put test start **********/
    @Test
    @DisplayName("given empty entries " +
            "when put one entry " +
            "then return size 1")
    public void put1() {
        // when
        map.put(1, 1);
        // then
        assertThat(map.size()).isOne();
    }

    @Test
    @DisplayName("given empty entries " +
            "when put two entries(duplicate key) " +
            "then return size 1")
    public void put2() {
        // when
        map.put(1, 1);
        map.put(1, 2);
        // then
        assertThat(map.size()).isEqualTo(1);
    }

    @Test
    @DisplayName("given empty entries " +
            "when put three entries " +
            "then return size 3")
    public void put3() {
        // when
        map.put(1, 1);
        map.put(2, 2);
        map.put(3, 3);
        // then
        assertThat(map.size()).isEqualTo(3);
    }

    @Test
    @DisplayName("should return the new value " +
            "when put is called")
    public void put4() {
        // when
        Integer value = map.put(1, 1);
        // then
        assertThat(value).isEqualTo(1);
    }

    @Test
    @DisplayName("given empty entries " +
            "when put multiples entries(hash conflict) " +
            "then keep the latest value for each key")
    public void put5() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        // when
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(3), 4);
        map.put(new HashConflict(3), 5);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        // then
        assertThat(Lists.newArrayList(map.values())).isEqualTo(Lists.list(1, 2, 5, 4, 5));
    }

    @Test
    @DisplayName("should auto grow " +
            "when size exceeds threshold")
    public void put6() {
        // given default threshold = 12 (16 * 0.75)
        // when
        for (int i = 1; i <= 20; i++) {
            map.put(i, i);
        }
        // then
        assertThat(map.size()).isEqualTo(20);
        assertThat(map.get(20)).isEqualTo(20);
    }
    /************ put test end **********/

    /************ get test start **********/
    @Test
    @DisplayName("given empty entries " +
            "when get by null key " +
            "then return null")
    public void get1() {
        // when
        Integer value = map.get(null);
        // then
        assertThat(value).isNull();
    }

    @Test
    @DisplayName("given empty entries " +
            "when get value by nonexistent key " +
            "then return null")
    public void get2() {
        // when
        Integer value = map.get(2);
        // then
        assertThat(value).isNull();
    }

    @Test
    @DisplayName("given entry " +
            "when get value by nonexistent key " +
            "then return null")
    public void get3() {
        // given
        map.put(1, 1);
        // when
        Integer value = map.get(2);
        // then
        assertThat(value).isNull();
    }

    @Test
    @DisplayName("given entry " +
            "when get value " +
            "then return value")
    public void get4() {
        // given
        map.put(1, 1);
        // when
        Integer value = map.get(1);
        // then
        assertThat(value).isEqualTo(1);
    }

    @Test
    @DisplayName("given multiple entries(hash conflict) " +
            "when get value by hash conflict key " +
            "then return value")
    public void get5() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(3), 4);
        map.put(new HashConflict(3), 5);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        // when
        Integer value = map.get(new HashConflict(3));
        // then
        assertThat(value).isEqualTo(5);
    }

    @Test
    @DisplayName("given multiple entries(hash conflict) " +
            "when get value by nonexistent hash conflict key " +
            "then return null")
    public void get6() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        // when
        Integer value = map.get(new HashConflict(6));
        // then
        assertThat(value).isNull();
    }
    /************ get test end **********/


    /************ remove test start **********/
    @Test
    @DisplayName("given empty entries " +
            "when remove by null key " +
            "then return null")
    public void remove1() {
        // when
        Integer value = map.remove(null);
        // then
        assertThat(value).isNull();
    }

    @Test
    @DisplayName("given entry " +
            "when remove by null key " +
            "then return null")
    public void remove2() {
        // given
        map.put(1, 1);
        // when
        Integer value = map.remove(null);
        // then
        assertThat(value).isNull();
    }

    @Test
    @DisplayName("given entry " +
            "when remove by key " +
            "then return value")
    public void remove3() {
        // given
        map.put(1, 1);
        // when
        int value = map.remove(1);
        // then
        assertThat(value).isEqualTo(1);
    }

    @Test
    @DisplayName("given entry " +
            "when remove by nonexistent key " +
            "then return null")
    public void remove4() {
        // given
        map.put(1, 1);
        // when
        Integer value = map.remove(2);
        // then
        assertThat(value).isNull();
    }

    @Test
    @DisplayName("given multiple entries(hash conflict) " +
            "when remove by hash conflict key " +
            "then return value")
    public void remove5() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        // when
        Integer value = map.remove(new HashConflict(3));
        // then
        assertThat(value).isEqualTo(3);
        assertThat(Lists.newArrayList(map.values())).isEqualTo(Lists.list(1, 2, 4, 5));
    }
    /************ remove test end **********/


    /************ values test start **********/
    @Test
    @DisplayName("given empty entries " +
            "when call values " +
            "then return empty values")
    public void values1() {
        // when
        Iterable<Integer> values = map.values();
        // then
        assertThat(values).isEmpty();
    }

    @Test
    @DisplayName("given multiple entries " +
            "when call values " +
            "then return all values")
    public void values2() {
        // given
        map.put(1, 1);
        map.put(2, 2);
        map.put(3, 3);
        map.put(3, 4);
        map.put(4, 4);
        map.remove(4);
        // when
        Iterable<Integer> values = map.values();
        // then
        assertThat(values.spliterator().estimateSize()).isEqualTo(3);
        assertThat(Lists.newArrayList(values)).isEqualTo(Lists.list(1, 2, 4));
    }
    /************ values test end **********/


    /************ containsKey test start **********/
    @Test
    @DisplayName("given entry " +
            "when key exists " +
            "then return true")
    public void containsKey1() {
        // given
        map.put(1, 1);
        // when
        boolean result = map.containsKey(1);
        // then
        assertThat(result).isTrue();
    }

    @Test
    @DisplayName("given entry " +
            "when key does not exist " +
            "then return false")
    public void containsKey2() {
        // given
        map.put(1, 1);
        // when
        boolean result = map.containsKey(2);
        // then
        assertThat(result).isFalse();
    }

    @Test
    @DisplayName("given multiple entries(hash conflict) " +
            "when call containsKey " +
            "then return correct result")
    public void containsKey3() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        // then
        assertThat(map.containsKey(new HashConflict(3))).isTrue();
        assertThat(map.containsKey(new HashConflict(5))).isTrue();
        assertThat(map.containsKey(new HashConflict(6))).isFalse();
    }
    /************ containsKey test end **********/


    /************ forEach test start **********/
    @Test
    @DisplayName("given multiple entries " +
            "when call forEach " +
            "then visit every key in table order")
    public void forEach1() {
        // given
        map.put(1, 1);
        map.put(2, 2);
        map.put(3, 3);
        map.put(4, 4);
        // when
        List<Integer> results = new ArrayList<>();
        map.forEach((key) -> results.add(map.get(key)));
        // then
        assertThat(results).isEqualTo(Lists.list(1, 2, 3, 4));
    }

    @Test
    @DisplayName("given multiple entries(hash conflict) " +
            "when call forEach " +
            "then visit every key in the chain")
    public void forEach2() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        // when
        List<Integer> results = new ArrayList<>();
        map.forEach((key) -> results.add(map.get(key)));
        // then
        assertThat(results).isEqualTo(Lists.list(1, 2, 3, 4, 5));
    }

    /************ forEach test end **********/

    class HashConflict {
        private int field;

        HashConflict(int field) {
            this.field = field;
        }

        @Override
        public int hashCode() {
            return this.field <= 8 ? 1 : this.field;
        }

        @Override
        public boolean equals(Object obj) {
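            // Guard against null and foreign types instead of failing with a ClassCastException/NPE.
            return obj instanceof HashConflict && ((HashConflict) obj).field == this.field;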
        }
    }
}

SimpleHashMap Source Code

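import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;

/**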
 * @author lyning
 */
public class SimpleHashMap<K, V> {
    private static final int DEFAULT_INITIAL_CAPACITY = 16;
    private static final float DEFAULT_LOAD_FACTOR = 0.75f;
    private int size;
    private Bucket<K, V>[] table;
    private int threshold;

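    public boolean containsKey(K key) {
        // index() reads table.length, so an empty map must short-circuit first.
        if (this.tableEmpty()) {
            return false;
        }
        int hash = this.hash(key);
        int index = this.index(hash);
        Bucket<K, V> bucket = this.table[index];
        return bucket != null
                && bucket.lookup(key) != null;
    }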

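    public void forEach(Consumer<? super K> action) {
        // Nothing to visit before the first put allocates the table.
        if (this.tableEmpty()) {
            return;
        }
        for (Bucket<K, V> bucket : this.table) {
            // Walk the whole chain stored in this slot.
            while (bucket != null) {
                action.accept(bucket.key);
                bucket = bucket.next;
            }
        }
    }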

    public V get(K key) {
        if (this.tableEmpty()) {
            return null;
        }
        int hash = this.hash(key);
        int index = this.index(hash);
        return this.getVal(index, key);
    }

    public V put(K key, V value) {
        if (this.tableEmpty() || this.nearByThreshold()) {
            this.resize();
        }
        int hash = this.hash(key);
        return this.putVal(key, value, hash);
    }

    public V remove(K key) {
        if (this.tableEmpty()) {
            return null;
        }
        int hash = this.hash(key);
        int index = this.index(hash);
        return this.removeVal(index, key);
    }

    public int size() {
        return this.size;
    }

    public Iterable<V> values() {
        if (this.tableEmpty()) {
            return new ArrayList<>();
        }
        List<V> collections = new ArrayList<>();
        this.collectValues(collections);
        return collections;
    }

    private void collectValues(List<V> collections) {
        for (Bucket<K, V> bucket : this.table) {
            while (bucket != null) {
                collections.add(bucket.value);
                bucket = bucket.next;
            }
        }
    }

    private Bucket<K, V> findBucket(int index) {
        return this.table[index];
    }

    private V getVal(int index, K key) {
        Bucket<K, V> bucket = this.findBucket(index);
        if (Objects.isNull(bucket) || Objects.isNull(bucket = bucket.lookup(key))) {
            return null;
        }
        return bucket.value;
    }

    private void grow(int newCap) {
        if (this.tableEmpty()) {
            this.initTable(newCap);
            return;
        }
        this.table = this.rebuildTable(newCap);
    }

    private int hash(K key) {
        int hashcode;
        return key == null
                ? 0
                : (hashcode = key.hashCode()) ^ (hashcode >>> 16);
    }

    private int index(int hash) {
        return hash & (this.table.length - 1);
    }

    private void initTable(int newCap) {
        this.table = new Bucket[newCap];
    }

    private boolean nearByThreshold() {
        return this.size + 1 >= this.threshold;
    }

    private V putVal(K key, V value, int hash) {
        int index = this.index(hash);
        Bucket<K, V> bucket = this.table[index];

        if (Objects.isNull(bucket)) {
            this.table[index] = new Bucket<>(hash, key, value);
        } else {
            Bucket<K, V> indexBucket = bucket.lookup(key);
            if (indexBucket != null) {
                indexBucket.value = value;
                return value;
            }
            bucket.putLast(new Bucket<>(hash, key, value));
        }
        this.size += 1;
        return value;
    }

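    @SuppressWarnings("unchecked")
    private Bucket<K, V>[] rebuildTable(int newCap) {
        Bucket<K, V>[] oldTable = this.table;
        Bucket<K, V>[] newTable = new Bucket[newCap];
        for (Bucket<K, V> bucket : oldTable) {
            // Rehash every node of the chain, not just its head: nodes that shared
            // an old slot may belong to different slots in the larger table.
            while (bucket != null) {
                Bucket<K, V> next = bucket.next;
                bucket.next = null;
                // Use the NEW capacity: index() would still see the old table here.
                int index = bucket.hash & (newCap - 1);
                if (newTable[index] == null) {
                    newTable[index] = bucket;
                } else {
                    newTable[index].putLast(bucket);
                }
                bucket = next;
            }
        }
        return newTable;
    }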

    private V removeVal(int index, K key) {
        Bucket<K, V> bucket = this.findBucket(index);
        Bucket<K, V> prev = null;
        while (bucket != null) {
            if (bucket.matchKey(key)) {
                if (Objects.isNull(prev)) {
                    // Removing the head: promote the next node instead of dropping the whole chain.
                    this.table[index] = bucket.next;
                } else {
                    prev.next = bucket.next;
                }
                this.size -= 1;
                return bucket.value;
            }
            prev = bucket;
            bucket = bucket.next;
        }
        return null;
    }

    private void resize() {
        int oldCap = this.tableCapacity();
        int newCap = 0;
        if (oldCap == 0) {
            oldCap = DEFAULT_INITIAL_CAPACITY;
            this.threshold = (int) (DEFAULT_INITIAL_CAPACITY * DEFAULT_LOAD_FACTOR);
        } else {
            newCap = oldCap << 1;
            this.threshold = this.threshold << 1;
        }

        if (newCap == 0) {
            newCap = oldCap;
        }
        this.grow(newCap);
    }

    private int tableCapacity() {
        return Objects.isNull(this.table) ? 0 : this.table.length;
    }

    private boolean tableEmpty() {
        return Objects.isNull(this.table);
    }

    static class Bucket<K, V> {
        Bucket<K, V> next;
        int hash;
        K key;
        V value;

        public Bucket(int hash, K key, V value) {
            this.hash = hash;
            this.key = key;
            this.value = value;
        }

        public Bucket<K, V> lookup(K key) {
            Bucket<K, V> bucket = this;
            while (bucket != null) {
                if (bucket.matchKey(key)) {
                    return bucket;
                }
                bucket = bucket.next;
            }
            return null;
        }

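        public boolean matchKey(K key) {
            // Reference check first, then a null-safe equals (a null key may be stored).
            return this.key == key || (this.key != null && this.key.equals(key));
        }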

        public void putLast(Bucket<K, V> bucket) {
            this.last().next = bucket;
        }

        private Bucket<K, V> last() {
            Bucket<K, V> bucket = this;
            while (true) {
                if (Objects.isNull(bucket.next)) {
                    return bucket;
                }
                bucket = bucket.next;
            }
        }
    }
}
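The fail-fast items from the Tasking list (`modCount` and `ConcurrentModificationException`) are not part of the source above. Here is a minimal sketch of the pattern, assuming a hypothetical fixed-size chained table named `FailFastTable` with no resizing or removal, so that it stays self-contained. Every insertion of a new key bumps `modCount`, while replacing a value does not. `forEach` snapshots `modCount` before iterating and compares it after each callback.

```java
import java.util.ConcurrentModificationException;
import java.util.Objects;
import java.util.function.Consumer;

public class FailFastTable<K, V> {
    static final class Node<K, V> {
        final int hash;
        final K key;
        V value;
        Node<K, V> next;

        Node(int hash, K key, V value) {
            this.hash = hash;
            this.key = key;
            this.value = value;
        }
    }

    // Fixed capacity keeps the sketch short; a real map would resize.
    @SuppressWarnings("unchecked")
    private final Node<K, V>[] table = (Node<K, V>[]) new Node[16];
    // Counts structural changes (new keys); replacing a value does not count.
    private int modCount;

    public void put(K key, V value) {
        int hash = key == null ? 0 : key.hashCode() ^ (key.hashCode() >>> 16);
        int index = hash & (table.length - 1);
        Node<K, V> last = null;
        for (Node<K, V> n = table[index]; n != null; n = n.next) {
            if (Objects.equals(n.key, key)) {
                n.value = value; // value replacement: modCount unchanged
                return;
            }
            last = n;
        }
        Node<K, V> node = new Node<>(hash, key, value);
        if (last == null) {
            table[index] = node;
        } else {
            last.next = node;
        }
        modCount++;
    }

    public void forEach(Consumer<? super K> action) {
        int expectedModCount = modCount; // snapshot before iterating
        for (Node<K, V> bucket : table) {
            for (Node<K, V> n = bucket; n != null; n = n.next) {
                action.accept(n.key);
                // Compare after every callback so iteration stops at the first structural change.
                if (modCount != expectedModCount) {
                    throw new ConcurrentModificationException();
                }
            }
        }
    }

    public static void main(String[] args) {
        FailFastTable<Integer, String> map = new FailFastTable<>();
        map.put(1, "a");
        map.put(2, "b");
        map.forEach(k -> System.out.println("key " + k));
        try {
            map.forEach(k -> map.put(k + 100, "x"));
        } catch (ConcurrentModificationException e) {
            System.out.println("ConcurrentModificationException");
        }
    }
}
```

The second `forEach` call inserts a new key from inside the callback, so it fails on the first key rather than iterating over a table that is changing underneath it. JDK `HashMap.forEach` compares once, after the loop. Checking after every callback, as here, simply reports the problem earlier.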

Summary

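The hardest part is undoubtedly the red-black tree, which is genuinely complex. After an hour I still had not grasped its essentials, so I used linked lists instead. When I have time, I will sit down and finish the remaining tasks.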
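The tree itself is out of scope here, but the condition that would trigger it is small. The sketch below assumes the thresholds of JDK 8+ `java.util.HashMap` (`TREEIFY_THRESHOLD = 8`, `MIN_TREEIFY_CAPACITY = 64`). The `TreeifyCheck` class, the `BinAction` enum, and the `decide` method are hypothetical names. The sketch roughly mirrors the JDK rule: a chain longer than 8 nodes becomes a tree only when the table has at least 64 slots. Smaller tables are resized instead, since a resize may split the long chain.

```java
public class TreeifyCheck {
    // Threshold values taken from JDK 8+ java.util.HashMap.
    static final int TREEIFY_THRESHOLD = 8;
    static final int MIN_TREEIFY_CAPACITY = 64;

    enum BinAction { KEEP_LIST, RESIZE, TREEIFY }

    // chainLength: number of nodes in the bucket after the new node was appended.
    static BinAction decide(int chainLength, int capacity) {
        if (chainLength <= TREEIFY_THRESHOLD) {
            return BinAction.KEEP_LIST;
        }
        // Small tables grow first: a resize may spread the long chain out.
        return capacity < MIN_TREEIFY_CAPACITY ? BinAction.RESIZE : BinAction.TREEIFY;
    }

    public static void main(String[] args) {
        System.out.println(decide(8, 16)); // KEEP_LIST
        System.out.println(decide(9, 16)); // RESIZE
        System.out.println(decide(9, 64)); // TREEIFY
    }
}
```

In SimpleHashMap, a check like this would naturally sit in `putVal`, right after `putLast` appends to a chain.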

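Understand the problem, break it down with Tasking, then apply TDD (including refactoring): these are the rules I have been following lately, and I hope they offer you some insight.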

Source Code