简单看一看ThreadLocal 源码

314 阅读6分钟

简单看一看ThreadLocal 源码

常用方法

 ThreadLocal<String> threadLocal = new ThreadLocal<>();
 threadLocal.set("aa");
 threadLocal.get();
 threadLocal.remove();

ThreadLocal 数据结构

image.png

Thread类有一个类型为ThreadLocal.ThreadLocalMap的实例变量threadLocals,也就是说每个线程有一个自己的ThreadLocalMap

ThreadLocalMap有自己的独立实现,可以简单地将它的key视作ThreadLocalvalue为代码中放入的值(实际上key并不是ThreadLocal本身,而是它的一个弱引用)。

每个线程在往ThreadLocal里放值的时候,都会往自己的ThreadLocalMap里存,读也是以ThreadLocal作为引用,在自己的map里找对应的key,从而实现了线程隔离

ThreadLocalMap有点类似HashMap的结构,只是HashMap是由数组+链表实现的,而ThreadLocalMap中并没有链表结构,Hash冲突使用的是线性探测法。 我们还要注意Entry, 它的keyThreadLocal<?> k ,继承自WeakReference, 也就是我们常说的弱引用类型。

  • 强引用:我们常常new出来的对象就是强引用类型,只要强引用存在,垃圾回收器将永远不会回收被引用的对象,哪怕内存不足的时候
  • 软引用:使用SoftReference修饰的对象被称为软引用,软引用指向的对象在内存要溢出的时候被回收
  • 弱引用:使用WeakReference修饰的对象被称为弱引用,只要发生垃圾回收,若这个对象只被弱引用指向,那么就会被回收
  • 虚引用:虚引用是最弱的引用,在 Java 中使用 PhantomReference 进行定义。虚引用中唯一的作用就是用队列接收对象即将死亡的通知

内存泄露问题

如果我们的ThreadLocal对象没有强引用,那么弱引用的key就会被回收,但是value没有被回收。线程不退出的话,value会一直存在,这种情况会出现内存泄露。因为一个线程有1个ThreadLocalMap,所以ThreadLocalMap的生命周期是和Thread相同的。

image.png

set方法

大致逻辑: ThreadLocal的set()方法 set() 方法调用会先判断当前线程是否存在 Thread类的成员变量的ThreadLocalMap。 1.map == null,则会初始化一个Thread类的ThreadLocalMap ,然后按照ThreadLocal做key计算hash得到数组内的index,因为是第一次初始化所以不存在Hash冲突,所以直接插入。 2.map != null,直接获取ThreadLocalMap,调用Map的set() 方法,执行相关逻辑。 ThreadLcoalMap的set()方法 计算Hash得到数组的index

  1. index节点为空,直接new Entry(key,value),然后插入
  2. index节点不为空,此时出现hash冲突。向后遍历``nextIndex()搜索,直到index节点不为null之后,会直接new Entry(key,value),然后按照table[I]插入。i的值是变化的,搜索的过程中按照顺序会有以下情况
    • key值相等,说明是同一个ThreadLocal对象,直接更新value值
    • key==null,可能存在引用过期的情况,执行replaceStaleEntry()方法,替换过期数据。以当前节点向前迭代,标记过期位置staleSlot。。
  3. 插入之后会执行 cleanSomeSlots 清理方法,并返回是否需要扩容。负载因子也是 0.75。扩容2倍。

image.png

ThreadLocalMap 不存在的情况下 Set()

我们先看下,当我们第一次使用Thread中的ThreadLocal进行set时的流程。

ThreadLocal.class 类中
    
public void set(T value) {
       // 获得当前线程
        Thread t = Thread.currentThread();
       // 按照线程获取当前map
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            // 不存在就初始化一个
            createMap(t, value);
    }


void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
}

 ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
           // hash计算 数组下标
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            // key 是 一个弱引用
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
 }

如果解决ThreadLocalMap的Hash冲突?

通过简单的看 set() 方法中的初始化 Map方法。ThreadLocalMap 是一个 hash存储的 Map。既然是hash肯定有解决hash冲突的步骤。

// 计算hash获取数组下标值
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);


public class ThreadLocal<T> {
    private final int threadLocalHashCode = nextHashCode();

    private static AtomicInteger nextHashCode = new AtomicInteger();

    private static final int HASH_INCREMENT = 0x61c88647;

    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }
    
    static class ThreadLocalMap {
        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);

            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }
    }
}

从代码中可以看出来,我们的关键是threadLocalHashCode 。我们每new一个ThreadLocal 对象 它的 threadLocalHashCode 就会累加 HASH_INCREMENT = 0x61c88647 大小。、这个值很特殊,它是斐波那契数 也叫 黄金分割数hash增量为 这个数字,带来的好处就是 hash 分布非常均匀。感兴趣的可以搜一下 斐波那契算法

解决办法 ** 我们模拟插入一个value=27的值,hash结果index=4,但是index=4已经存在值,所以会向后遍历查找为null的节点,发现index=8还没有值,然后插入对应值。这就是经典的hash冲突解决算法,线性探测法image.png

ThreadLocalMap 存在的情况下 Set()

回归到我们的set方法,我们经过第一次初始化 ThreadLocalMap后,再进行 set() 操作的时候,执行流程是什么样的那?通过观察源码主要逻辑在

ThreadLocalMap.java

private void set(ThreadLocal<?> key, Object value) {

            Entry[] tab = table;
            int len = tab.length;
            // 计算hash得到 数组 index 
            int i = key.threadLocalHashCode & (len-1);
			// 从当前index 迭代 数组 table[i]==null 退出循环
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();
				// key相等,代表同一个 threadlocal 对象,直接更新 value
                if (k == key) {
                    e.value = value;
                    return;
                }
				// 节点key==null,说明探测到过期节点
                if (k == null) {
                	// 此处 staleSlot 过期节点位置 
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
			
            // 新插入或覆盖节点
            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

for 循环逻辑:

  1. 遍历当前key值对应的桶中Entry数据为空,这说明散列数组这里没有数据冲突,跳出for循环,直接set数据到对应的桶中
  2. 如果key值对应的桶中Entry数据不为空
    1. k = key,说明当前set`操作是一个替换操作,做替换逻辑,直接返回
    2. key = null,说明当前桶位置的Entry是过期数据,执行replaceStaleEntry()方法(核心方法),然后返回
  3. for循环执行完毕,继续往下执行说明向后迭代的过程中遇到了entrynull的情况
    1. Entrynull的桶中创建一个新的Entry对象
    2. ++size操作
  4. 调用cleanSomeSlots()做一次启发式清理工作,清理散列数组中Entrykey过期的数据
    1. 果清理工作完成后,未清理到任何数据,且size超过了阈值(数组长度的2/3),进行rehash()操作
    2. rehash()中会先进行一轮探测式清理,清理过期key`,清理完成后如果size >= threshold - threshold / 4,就会执行真正的扩容逻辑(扩容逻辑往后看)

get方法

get方法相对就简单了许多,大体逻辑:

获取Thread的ThreadLocalMap对象。 Map == null ,执行set初始值方法。 Map != null , 执行 getEntry方法。getEntry方法:按照hash计算index值,table[i]的key相等,则直接返回不相等则代表之前产生hash冲突,向后遍历直到table[i]==null查找结束或者查找key相等的节点返回value值。如果遍历过程中发现 table[i].key 为null的过期节点,会进行探测式数据回收。

image.png 我们以get(ThreadLocal1)为例,通过hash计算后,正确的slot位置应该是4,而index=4的槽位已经有了数据,且key值不等于ThreadLocal1,所以需要继续往后迭代查找。 迭代到index=5的数据时,此时Entry.key=null,触发一次探测式数据回收操作,执行expungeStaleEntry()方法,执行完后,index 5,8的数据都会被回收,而index 6,7的数据都会前移,此时继续往后迭代,到index = 6的时候即找到了key值相等的Entry数据,如下图所示:

image.png

参考:www.jianshu.com/p/134d72d37…