【JUC源码】ThreadLocal源码深析&使用示例

247 阅读8分钟

在开始看源码之前,我们必须要知道ThreadLocal有什么作用:ThreadLocal 使同一个变量在不同线程间隔离,即每个线程都可以有自己独立的副本,然后可以在该线程的方法间共享(随时取出使用)。不明白的话可以看文章第三部分的使用示例。

这其实是一种空间换时间的思路,因为如果每个线程都有自己独立的副本,就不用通过加锁使线程串行化执行去保证线程安全了,节省了时间,但作为代价要为每个线程开辟一块独立的空间。

了解了 ThreadLocal 的功能后,那我们该如何设计ThreadLocal?

1.如何设计线程间隔离

首先,很容易想到每个线程都必须为 ThreadLocal 开辟一块单独的内存,但仅仅开辟一块大小等于ThreadLocal的内存是不够的的。因为一个线程可能有多个独立的副本,换句话说就是可以在多个类中创建 ThreadLocal,比如:

public class A {
	private static ThreadLocal<String> threadLocal1 = new ThreadLocal<>();
	
	public void f1(){ threadLocal1.set("test");
}
public class B {
	private static ThreadLocal<Integer> threadLocal2 = new ThreadLocal<>();

	public void f2(){ threadLocal1.set(001);
}

那么对于执行的线程来说,test和001都是它的独立副本,都要保存起来,而他俩的区别就在于具体ThreadLocal对象不同。

接下来,我们就看看在线程(Thread类)中到底是如何保存ThreadLocal的: 在这里插入图片描述 可以看到,每个Thread维护一个ThreadLocalMap,而存储在ThreadLocalMap内的就是一个以Entry为元素的table数组(Entry就是一个key-value结构:key为ThreadLocal,value为存储的值),所以我们可以得到以下两点信息:

  1. 数组保证了每个线程可以存储多个独立的副本
  2. Entry 提供了区分不同副本方式,即ThreadLocal不同

另外,虽然这里有两个变量,但只有 threadLocals 是直接进行set/get操作的。若在父线程中创建子线程,会拷贝父线程的inheritableThreadLocals到子线程。

看源码前的要理解的逻辑终于说完了,下面进入正戏......

2.ThreadLocal

ThreadLocal 核心成员变量及主要构造函数:

// ThreadLocal使用了泛型,所以可以存放任何类型
public class ThreadLocal<T> {
    
    // 当前 ThreadLocal 的 hashCode,作用是计算当前 ThreadLocal 在 ThreadLocalMap 中的索引位置
    // nextHashCode() { return nextHashCode.getAndAdd(HASH_INCREMENT);}
    private final int threadLocalHashCode = nextHashCode();
   
    // nextHashCode 直接决定 threadLocalHashCode(= nextHashCode++)
    // 这么做因为ThreadLocal可能在不同类中new出来多个,但线程只有一个,若每次下标都从同一位置开始,虽然有hash碰撞处理策略,但仍然会影响效率
    // static:保证了nextHashCode的唯一性,间接保证了threadHashCode唯一性
    private static AtomicInteger nextHashCode = new AtomicInteger();
    
    static class ThreadLocalMap{...}
    
    // 只有空参构造
    public ThreadLocal() {
    }
	
	// 计算 ThreadLocal 的 hashCode 值,就是通过CAS让 nextHashCode++
	private static int nextHashCode() {
    	return nextHashCode.getAndAdd(HASH_INCREMENT);
	}
	
	//......
}

2.1 set()

拿到当前线程的 threadLocals 并将 Entry(当前ThreadLocal对象,value)放入。另外,因为 set 操作每个线程都是串行的,所以不会有线程安全的问题

public void set(T value) {
    // 拿到当前线程
    Thread t = Thread.currentThread();
    // 拿到当前线程的ThreadLocalMap,即threadLocals变量
    ThreadLocalMap map = getMap(t);
    
    // 当前 thradLocal 非空,即之前已经有独立的副本数据了
    if (map != null)
        map.set(this, value); // 直接将当前 threadLocal和value传入
    // 当前threadLocal为空
    else
        createMap(t, value); // 初始化ThreadLocalMap
}

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

2.2 get()

在当前线程的 theadLocals 中获取当前ThreadLocal对象对应的value

  1. 在当前线程拿到threadLocals
  2. 若threadLocals=null,则将其初始化
  3. 通过当前ThreadLocal对象获取到相应Entry
    • entry != null ,返回result
    • entry = null ,返回null
public T get() {
    // 拿出当前线程
    Thread t = Thread.currentThread();
    // 从线程中拿到 threadLocals(ThreadLocalMap)
    ThreadLocalMap map = getMap(t);
    
    if (map != null) {
        // 从 map 中拿到相应entry
        ThreadLocalMap.Entry e = map.getEntry(this);
        // 如果不为空,读取当前 ThreadLocal 中保存的值
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    // 否则给当前线程的 ThreadLocal 初始化,并返回初始值 null
    return setInitialValue();
}

3.ThreadLocalMap

  • 虽然是内部类,但 ThreadLocalMap 不像 List 的 Node 是 List 的组成部分(List > Node)
  • ThreadLocalMap是用来给Thread作为属性,并保存ThreadLocal的 (Thread > ThreadLocalMap > ThreadLocal)
// 静态内部类,可直接被外部调用
static class ThreadLocalMap {
        // Entry(k,v)
    	// k = WeakReference 是弱引用,当没有引用指向时,会直接被回收
        static class Entry extends WeakReference<ThreadLocal<?>> {
            // 当前 ThreadLocal 关联的值
            Object value;
            // WeakReference 的引用 referent 就是 ThreadLocal
            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }
    	
    	// 存储 (ThreadLocal,Obj) 的数组
        private Entry[] table;
        // 数组的初始化大小
        private static final int INITIAL_CAPACITY = 16; 
        // 扩容的阈值,默认是数组大小的三分之二
        private int threshold;
		
		//.......
}

3.1 set()

将 Entry(threadLocal,Object Value)放入 threadLocals的数组

  1. 获取到 threadLocals 的数组
  2. 计算当前ThreadLocal对应的数组下标
  3. 将Entry(threadLocal,Object Value)放入数组
    • 无hash碰撞,new Entry放入
    • 若出现hash碰撞,则i++,直到找到没有Entry的位置,new Entry放入
    • 若碰见key相同(ThreadLocal),则替换value
  4. 判断是否需要扩容
private void set(ThreadLocal<?> key, Object value) {
    // 1.拿到当前threadLocals的数组
    Entry[] tab = table;
    int len = tab.length;
    // 2.计算当前 ThreadLocal 在数组中的下标,其实就是 ThreadLocal 的 hashCode 和数组大小-1取余
    int i = key.threadLocalHashCode & (len-1);
	
	// 可以看到循环的结束条件是 tab[i]==null,即无哈希冲突
	// 若出现哈希冲突时,依次向后(i++)寻找空槽点。nextIndex方法就是让在不超过数组长度的基础上,把数组的索引位置 + 1
    for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();
        // 找到内存地址一样的 ThreadLocal,直接替换
        if (k == key) {
            e.value = value;
            return;
        }
        // 当前 key 是 null,说明 ThreadLocal 被清理了,直接替换掉
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    // 当前 i 位置是无值的,可以被当前 thradLocal 使用
    tab[i] = new Entry(key, value);
    int sz = ++size;
    
    // 当数组大小大于等于扩容阈值(数组大小的三分之二)时,进行扩容
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

3.2 getEntry()

获取相应节点Entry

  1. 计算当前ThreadLocal对应的索引位置(hashcode 取模数组大小-1 )
  2. 若 e != null,返回当前Entry
  3. 若 e == null 或 有但key(ThreadLocal)不符,调用 getEntryAfterMiss 自旋进行寻找
private Entry getEntry(ThreadLocal<?> key) {
    // 计算索引位置:ThreadLocal 的 hashCode 取模数组大小-1
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    
    // e 不为空 && e 的 ThreadLocal 的内存地址和 key 相同
    if (e != null && e.get() == key)
        return e; // 直接返回
    // 因为上面解决Hash冲突的方法是i++,所以会出现计算出的槽点为空或者不等于当前ThreadLocal的情况
    else
        return getEntryAfterMiss(key, i, e); // 继续通过 getEntryAfterMiss 方法找
}

getEntryAfterMiss:根据 thradLocalMap set 时解决数组索引位置冲突的逻辑,该方法的寻找逻辑也是对应的,即自旋 i+1,直到找到为止

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;
    // 在大量使用不同 key 的 ThreadLocal 时,其实还蛮耗性能的
    while (e != null) {
        ThreadLocal<?> k = e.get();
        // 内存地址一样,表示找到了
        if (k == key)
            return e;
        // 删除没用的 key
        if (k == null)
            expungeStaleEntry(i);
        // 继续使索引位置 + 1
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

3.3 resize()

ThreadLocalMap 中的 ThreadLocal 的个数超过阈值时,ThreadLocalMap 就要开始扩容了

  1. 拿到threadLocals的table
  2. 初始化新数组,大小为原来2倍
  3. 将老数组拷贝到新数组
    • 根据key(ThreadLocal)计算新的索引位置
    • 若出现hash碰撞,i++
  4. 计算新的扩容阈值,将新数组赋给table
private void resize() {
    // 1.拿出旧的数组
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    // 2.计算新数组的大小,为老数组的两倍
    int newLen = oldLen * 2;
    // 初始化新数组
    Entry[] newTab = new Entry[newLen];
    int count = 0;
    
    // 3.老数组的值拷贝到新数组上
    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == null) {
                e.value = null; // Help the GC
            } else {
                // 计算 ThreadLocal 在新数组中的位置
                int h = k.threadLocalHashCode & (newLen - 1);
                // 如果出现哈希冲突,即索引 h 的位置值不为空,往后+1,直到找到值为空的索引位置
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                // 给新数组赋值
                newTab[h] = e;
                count++;
            }
        }
    }
    // 4.计算新数组下次扩容阈值,为数组长度的三分之二
    setThreshold(newLen);
    size = count;
    table = newTab;
}

4.使用示例

下面启动 9 个子线程,每个子线程都将名字(thread - i)存在同一个变量中,然后再打印出来。可以想到,如果同一个变量可以做到线程间隔离(互补影响),控制台正确的结果应该是 thread - 0 到 thread - 8。

下面就分别演示演示这个变量的两种实现:1.普通变量String ,2.ThreadLocal<String>

4.1 普通变量

public class StringTest {
	// 保存线程名的普通变量value
    private String value;
	// 不直接设置value,而是暴露出get和set方法
    private String getString() { return string; }
    private void setString(String string) { this.string = string; }

    public static void main(String[] args) {
    
        StringTest test= new StringTest ();
        
        int threads = 9; // 要启动的线程个数
        CountDownLatch countDownLatch = new CountDownLatch(threads); // countDownLatch 用于防止主线程在子线程未完成前结束
       
        // 启动9个子线程
        for (int i = 0; i < threads; i++) {
            Thread thread = new Thread(() -> {
                test.setString(Thread.currentThread().getName()); // 向变量value中存入线程名 thread - i
                System.out.println(test.getString()); // 然后打印出来。注:这里可能存在并发
                countDownLatch.countDown(); // 门栓-1
            }, "thread - " + i); 
            thread.start();
        }
    }

	countDownLatch.await(); // 等countDownLatch为0时,主线程恢复运行
}

结果如下:

thread - 1
thread - 2
thread - 1
thread - 3
thread - 4
thread - 5
thread - 6
thread - 7
thread - 8

可以看到没有 thread - 0,反而 thread - 1 出现了两次,所以使用普通类型的变量无法实现同一变量对于不同线程隔离

4.2 ThreadLocal

使用ThreadLocal时,一般声明为static

  • 一个类一个ThreadLocal--> 当前线程一个Entry 就够了
  • 使用时调用方便
public class ThreadLocalStringTest {
	// 保存线程名的ThreadLocal变量threadLocal
	// 注:这里除了是String,也可是别的任何类型(Integer,List,Map...)
    private static ThreadLocal<String> threadLocal = new ThreadLocal<>();
	// 不直接操作 threadLocal,而是封装成 set/get 方法
    private String getString() { return threadLocal.get(); }
    private void setString(String string) { threadLocal.set(string);}

    public static void main(String[] args) {
    
        ThreadLocalStringTest test= new ThreadLocalStringTest();
        
        int threads = 9; // 要创建的子线程个数
        CountDownLatch countDownLatch = new CountDownLatch(threads); // countDownLatch 用于防止主线程在子线程未完成前结束
		
        // 创建 9 个线程
        for (int i = 0; i < threads; i++) {
            Thread thread = new Thread(() -> {
                test.setString(Thread.currentThread().getName()); // 向ThreadLocal中存入当前线程名 thread - i
                System.out.println(test.getString()); // 向ThreadLocal获取刚存的线程名。注:可能存在并发
                countDownLatch.countDown(); // 门栓-1
            }, "thread - " + i);
            thread.start();
        }
		
		countDownLatch.await(); // 等countDownLatch为0时,主线程恢复运行
    }

}

运行结果:

thread - 0
thread - 1
thread - 2
thread - 3
thread - 4
thread - 5
thread - 6
thread - 7
thread - 8

可以看到运行结果符合预期,即ThreadLocal实现了同一变量在线程间隔离。另外ThreadLocal还可以用于单点登录,用来保存不同请求线程的token,然后在解析时取出。