ThreadLocal源码深度剖析

7,171 阅读12分钟

ThreadLocal的作用

ThreadLocal的作用是提供线程内的局部变量,说白了,就是在各线程内部创建一个变量的副本,相比于使用各种锁机制访问变量,ThreadLocal的思想就是用空间换时间,使各线程都能访问属于自己这一份的变量副本,变量值不互相干扰,减少同一个线程内的多个函数或者组件之间一些公共变量传递的复杂度。

ThreadLocal的使用

查看官方文档,可知ThreadLocal包含以下方法:

其中get函数用来获取与当前线程关联的ThreadLocal的值,如果当前线程没有该ThreadLocal的值,则调用initialValue函数获取初始值返回,initialValue是protected类型的,所以一般我们使用时需要继承该函数,给出初始值。而set函数是用来设置当前线程的该ThreadLocal的值,remove函数用来删除ThreadLocal绑定的值,在某些情况下需要手动调用,防止内存泄露。

代码示例

public class Main {
    private static final ThreadLocal<Integer> threadLocal = new ThreadLocal<Integer>(){
        @Override
        protected Integer initialValue() {
            return Integer.valueOf(0);
        }
    };

    public static void main(String[] args) {
        Thread[] threads = new Thread[5];
        for (int i=0;i<5;i++) {
            threads[i] = new Thread(new Runnable() {
                @Override
                public void run() {
                    System.out.println(Thread.currentThread() +"'s initial value: " + threadLocal.get());
                    for (int j=0;j<10;j++) {
                        threadLocal.set(threadLocal.get() + j);
                    }
                    System.out.println(Thread.currentThread() +"'s last value: " + threadLocal.get());
                }
            });
        }

        for (Thread t: threads)
            t.start();
    }
}

上面的例子,threadLocal是每线程独有的,所以就算进行累加后,彼此的值也是互不影响的,最后输出如下:

ThreadLocal源码剖析

看了ThreadLocal的基本使用,读者一定想知道ThreadLocal内部是如何实现的?本文主要的内容就是基于Java1.8的代码带你剖析ThreadLocal的内部实现。首先看下面这幅图:

我们可以看出每个Thread维护一个ThreadLocalMap,存储在ThreadLocalMap内的就是一个以Entry为元素的table数组,Entry就是一个key-value结构,key为ThreadLocal,value为存储的值。类比HashMap的实现,其实就是每个线程借助于一个哈希表,存储线程独立的值。我们可以看看Entry的定义:

static class Entry extends WeakReference<ThreadLocal<?>> {
    /** The value associated with this ThreadLocal. */
    Object value;

    Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
    }
}

这里ThreadLocal和key之间的线是虚线,因为Entry是继承了WeakReference实现的,当ThreadLocal Ref销毁时,指向堆中ThreadLocal实例的唯一一条强引用消失了,只有Entry有一条指向ThreadLocal实例的弱引用,假设你知道弱引用的特性,那么这里ThreadLocal实例是可以被GC掉的。这时Entry里的key为null了,那么直到线程结束前,Entry中的value都是无法回收的,这里可能产生内存泄露,后面会说如何解决。

知道大概的数据结构后,我们来探究一下ThreadLocal具体几个方法的实现:

get方法

public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }

直接看代码,可以分析主要有以下几步:

  1. 获取当前的Thread对象,通过getMap获取Thread内的ThreadLocalMap
  2. 如果map已经存在,以当前的ThreadLocal为键,获取Entry对象,并从从Entry中取出值
  3. 否则,调用setInitialValue进行初始化。

下面再看上面具体提到的几个函数:

getMap

 ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

getMap很简单,就是返回线程中ThreadLocalMap,跳到Thread源码里看,ThreadLocalMap是这么定义的:

ThreadLocal.ThreadLocalMap threadLocals = null;

所以ThreadLocalMap还是定义在ThreadLocal里面的,我们前面已经说过ThreadLocalMap中的Entry定义,下面为了先介绍ThreadLocalMap的定义我们把setInitialValue放在前面说。

setInitialValue

private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}

setInititialValue在Map不存在的时候调用

  1. 首先是调用initialValue生成一个初始的value值,深入initialValue函数,我们可知它就是返回一个null;
  2. 然后还是在get以下Map,如果map存在,则直接map.set,这个函数会放在后文说;
  3. 如果不存在则会调用createMap创建ThreadLocalMap,这里正好可以先说明下ThreadLocalMap了。

ThreadLocalMap

createMap方法的定义很简单:

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

就是调用ThreadLocalMap的构造函数生成一个map,下面我们看看ThreadLocalMap的定义:

static class ThreadLocalMap {
    static class Entry extends WeakReference<ThreadLocal<?>> {
        /** The value associated with this ThreadLocal. */
        Object value;

        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }

    /**
     * The initial capacity -- MUST be a power of two.
     */
    private static final int INITIAL_CAPACITY = 16;

    /**
     * The table, resized as necessary.
     * table.length MUST always be a power of two.
     */
    private Entry[] table;

    /**
     * The number of entries in the table.
     */
    private int size = 0;

    /**
     * The next size value at which to resize.
     */
    private int threshold; // Default to 0

    /**
     * Set the resize threshold to maintain at worst a 2/3 load factor.
     */
    private void setThreshold(int len) {
        threshold = len * 2 / 3;
    }

    /**
     * Increment i modulo len.
     */
    private static int nextIndex(int i, int len) {
        return ((i + 1 < len) ? i + 1 : 0);
    }

    /**
     * Decrement i modulo len.
     */
    private static int prevIndex(int i, int len) {
        return ((i - 1 >= 0) ? i - 1 : len - 1);
    }

ThreadLocalMap被定义为一个静态类,上面是包含的主要成员:

  1. 首先是Entry的定义,前面已经说过;
  2. 初始的容量为INITIAL_CAPACITY = 16
  3. 主要数据结构就是一个Entry的数组table;
  4. size用于记录Map中实际存在的entry个数;
  5. threshold是扩容上限,当size到达threashold时,需要resize整个Map,threshold的初始值为len * 2 / 3
  6. nextIndex和prevIndex则是为了安全的移动索引,后面的函数里经常用到。

而ThreadLocalMap的构造函数如下:

ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
    table = new Entry[INITIAL_CAPACITY];
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    table[i] = new Entry(firstKey, firstValue);
    size = 1;
    setThreshold(INITIAL_CAPACITY);
}

就是使用firstKey和firstValue创建一个Entry,计算好索引i,然后把创建好的Entry插入table中的i位置,再设置好size和threshold。

最后说get函数完成实质性功能的getEntry函数:

map.getEntry

private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        return e;
    else
        return getEntryAfterMiss(key, i, e);
}
  1. 首先是计算索引位置i,通过计算key的hash%(table.length-1)得出;
  2. 根据获取Entry,如果Entry存在且Entry的key恰巧等于ThreadLocal,那么直接返回Entry对象;
  3. 否则,也就是在此位置上找不到对应的Entry,那么就调用getEntryAfterMiss。

这么一深入,引出了getEntryAfterMiss方法:

getEntryAfterMiss

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key)
            return e;
        if (k == null)
            expungeStaleEntry(i);
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

这个方法我们还得结合上一步看,上一步是因为不满足e != null && e.get() == key才沦落到调用getEntryAfterMiss的,所以首先e如果为null的话,那么getEntryAfterMiss还是直接返回null的,如果是不满足e.get() == key,那么进入while循环,这里是不断循环,如果e一直不为空,那么就调用nextIndex,不断递增i,在此过程中一直会做两个判断:

  1. 如果k==key,那么代表找到了这个所需要的Entry,直接返回;
  2. 如果k==null,那么证明这个Entry中key已经为null,那么这个Entry就是一个过期对象,这里调用expungeStaleEntry清理该Entry。 这里解答了前面留下的一个坑,即ThreadLocal Ref销毁时,ThreadLocal实例由于只有Entry中的一条弱引用指着,那么就会被GC掉,Entry的key没了,value可能会内存泄露的,其实在每一个get,set操作时都会不断清理掉这种key为null的Entry的。

为什么循环查找?

这里你可以直接跳到下面的set方法,主要是因为处理哈希冲突的方法,我们都知道HashMap采用拉链法处理哈希冲突,即在一个位置已经有元素了,就采用链表把冲突的元素链接在该元素后面,而ThreadLocal采用的是开放地址法,即有冲突后,把要插入的元素放在要插入的位置后面为null的地方,具体关于这两种方法的区别可以参考:解决哈希(HASH)冲突的主要方法。所以上面的循环就是因为我们在第一次计算出来的i位置不一定存在key与我们想查找的key恰好相等的Entry,所以只能不断在后面循环,来查找是不是被插到后面了,直到找到为null的元素,因为若是插入也是到null为止的。

分析完循环的原因,其实也可以深入expungeStaleEntry看看是怎么清理的。

expungeStaleEntry

private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // expunge entry at staleSlot
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    // Rehash until we encounter null
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;

                // Unlike Knuth 6.4 Algorithm R, we must scan until
                // null because multiple entries could have been stale.
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}

看上面这段代码主要有两部分:

  1. expunge entry at staleSlot:这段主要是将i位置上的Entry的value设为null,Entry的引用也设为null,那么系统GC的时候自然会清理掉这块内存;
  2. Rehash until we encounter null: 这段就是扫描位置staleSlot之后,null之前的Entry数组,清除每一个key为null的Entry,同时若是key不为空,做rehash,调整其位置。

为什么要做rehash呢?

因为我们在清理的过程中会把某个值设为null,那么这个值后面的区域如果之前是连着前面的,那么下次循环查找时,就会只查到null为止。

举个例子就是:...,<key1(hash1), value1>, <key2(hash1), value2>,...(即key1和key2的hash值相同) 此时,若插入<key3(hash2), value3>,其hash计算的目标位置被<key2(hash1), value2>占了,于是往后寻找可用位置,hash表可能变为: ..., <key1(hash1), value1>, <key2(hash1), value2>, <key3(hash2), value3>, ... 此时,若<key2(hash1), value2>被清理,显然<key3(hash2), value3>应该往前移(即通过rehash调整位置),否则若以key3查找hash表,将会找不到key3

set方法

我们在get方法的循环查找那里也大概描述了set方法的思想,即开放地址法,下面看具体代码:

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

首先也是获取当前线程,根据线程获取到ThreadLocalMap,若是有ThreadLocalMap,则调用map.set(ThreadLocal<?> key, Object value),若是没有则调用createMap创建。

map.set

private void set(ThreadLocal<?> key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();
    
        if (k == key) {
            e.value = value;
            return;
        }
    
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    tab[i] = new Entry(key, value);
    int sz = ++size;
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

看上面这段代码:

  1. 首先还是根据key计算出位置i,然后查找i位置上的Entry,
  2. 若是Entry已经存在并且key等于传入的key,那么这时候直接给这个Entry赋新的value值。
  3. 若是Entry存在,但是key为null,则调用replaceStaleEntry来更换这个key为空的Entry
  4. 不断循环检测,直到遇到为null的地方,这时候要是还没在循环过程中return,那么就在这个null的位置新建一个Entry,并且插入,同时size增加1。
  5. 最后调用cleanSomeSlots,这个函数就不细说了,你只要知道内部还是调用了上面提到的expungeStaleEntry函数清理key为null的Entry就行了,最后返回是否清理了Entry,接下来再判断sz>thresgold,这里就是判断是否达到了rehash的条件,达到的话就会调用rehash函数。

上面这段代码有两个函数还需要分析下,首先是

replaceStaleEntry

private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                               int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;

    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == key) {
            e.value = value;

            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            // Start expunge at preceding stale entry if it exists
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // If key not found, put new entry in stale slot
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // If there are any other stale entries in run, expunge them
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

首先我们回想上一步是因为这个位置的Entry的key为null才调用replaceStaleEntry。

  1. 第1个for循环:我们向前找到key为null的位置,记录为slotToExpunge,这里是为了后面的清理过程,可以不关注了;
  2. 第2个for循环:我们从staleSlot起到下一个null为止,若是找到key和传入key相等的Entry,就给这个Entry赋新的value值,并且把它和staleSlot位置的Entry交换,然后调用CleanSomeSlots清理key为null的Entry。
  3. 若是一直没有key和传入key相等的Entry,那么就在staleSlot处新建一个Entry。函数最后再清理一遍空key的Entry。

说完replaceStaleEntry,还有个重要的函数是rehash以及rehash的条件:

首先是sz > threshold时调用rehash

rehash

private void rehash() {
    expungeStaleEntries();

    // Use lower threshold for doubling to avoid hysteresis
    if (size >= threshold - threshold / 4)
        resize();
}

清理完空key的Entry后,如果size大于3/4的threshold,则调用resize函数:

private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    int count = 0;

    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == null) {
                e.value = null; // Help the GC
            } else {
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                newTab[h] = e;
                count++;
            }
        }
    }

    setThreshold(newLen);
    size = count;
    table = newTab;
}

由源码我们可知每次扩容大小扩展为原来的2倍,然后再一个for循环里,清除空key的Entry,同时重新计算key不为空的Entry的hash值,把它们放到正确的位置上,再更新ThreadLocalMap的所有属性。

remove

最后一个需要探究的就是remove函数,它用于在map中移除一个不用的Entry。也是先计算出hash值,若是第一次没有命中,就循环直到null,在此过程中也会调用expungeStaleEntry清除空key节点。代码如下:

private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
         if (e.get() == key) {
            e.clear();
            expungeStaleEntry(i);
            return;
        }
    }
}

使用ThreadLocal的最佳实践

我们发现无论是set,get还是remove方法,过程中key为null的Entry都会被擦除,那么Entry内的value也就没有强引用链,GC时就会被回收。那么怎么会存在内存泄露呢?但是以上的思路是假设你调用get或者set方法了,很多时候我们都没有调用过,所以最佳实践就是*

1 .使用者需要手动调用remove函数,删除不再使用的ThreadLocal.

2 .还有尽量将ThreadLocal设置成private static的,这样ThreadLocal会尽量和线程本身一起消亡。