1.前言
1.1 ThreadLocal基本原理
ThreadLocal
是 Java 中的一个非常有用的类,它提供了一种线程局部变量,即每个线程都可以访问到自己独立初始化过的变量副本,这个变量对其他线程是不可见的。最常见的用法就是用户请求携带用户ID请求某个接口的时候,在整个链路中需要用户信息的时候,通过AOP将用户信息查出来放到ThreadLocal当中。
其本质上是将共享变量放到每个线程的ThreadLocalMap成员变量中,更直白的说就是把这个变量的副本存到了Thread中。ThreadLocalMap是一个重写的HashMap,重写目的在于将Map的Key设置为弱引用(当发生GC时,不管内存空间是否充足,都会对弱引用的对象进行回收),方便垃圾回收。后面我们会具体讲到。
1.2 有关Thread的前置知识
// 与此线程相关的 ThreadLocal 值
ThreadLocal.ThreadLocalMap threadLocals = null;
// 与此线程相关的 InheritableThreadLocal 值
ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;
这里我们可以看到Thread类实际上是有两个ThreadLocalMap类型的成员变量的。其中inheritableThreadLocals
的主要作用是用于父子线程之间的共享变量传递。
当我们创建Thread的时候,通过其构造函数,最终会执行到Thread的init()
方法:
public Thread(Runnable target) {
init(null, target, "Thread-" + nextThreadNum(), 0);
}
private void init(ThreadGroup g, Runnable target, String name,
long stackSize, AccessControlContext acc,
boolean inheritThreadLocals) {
...
// 如果允许子线程共享父线程的变量副本,并且父线程的变量副本集合不为空,那么子线程将复制一份父线程的变量副本集合
if (inheritThreadLocals && parent.inheritableThreadLocals != null)
this.inheritableThreadLocals =
ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
...
}
这里我们可以看到,子线程被创建的时候,会判断父线程的inheritThreadLocals
是否为null,如果不为null会复制一份。复制一份的目的就是为了能将父线程的变量副本传递给子线程。
2.ThreadLocal
2.1 ThreadLocal成员变量
// ThreadLocal的哈希值
private final int threadLocalHashCode = nextHashCode();
// 要给出的下一个哈希代码。以原子方式更新。从零开始。
private static AtomicInteger nextHashCode = new AtomicInteger();
// HASH_INCREMENT: 表示hash值的增量
// 每创建一个ThreadLocal对象,ThreadLocal.nextHashCode的值就会增长HASH_INCREMENT(0x61c88647)
// 这个值很特殊,它是斐波那契数也叫黄金分割数。hash增量为这个数字,带来的好处就是hash分布非常均匀。
private static final int HASH_INCREMENT = 0x61c88647;
/**
* 返回下一个Hash Code
*/
private static int nextHashCode() {
return nextHashCode.getAndAdd(HASH_INCREMENT);
}
前面我们也讲过了,共享变量实际上是存放在Thread的ThreadLocalMap中,并且这个Map的Key是ThreadLocal。
这里我们可以看到会调用nextHashCode()
方法生成ThreadLocal的哈希值,nextHashCode()
方法通过名为nextHashCode
的AtomicInteger
类型变量自增获取下一个HashCode。可以注意到这里nextHashCode
被static修饰,意味着是一个类变量。有关于AtomicInteger
大家可以去看看相关博客,其本质上是通过CAS实现自旋更新保证线程安全,避免了加互斥锁导致的资源开销。
2.2 关键源码
set()
public void set(T value) {
// 获取当前线程
Thread t = Thread.currentThread();
// 获取线程本地的ThreadLocalMap
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
// 初始化table,默认长度是16
table = new Entry[INITIAL_CAPACITY];
// 通过ThreadLocal对象的hashCode与INITIAL_CAPACITY-1进行与运算得出其应该存放元素的下标
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
// 将值存入散列表table中并更新size为1
table[i] = new Entry(firstKey, firstValue);
size = 1;
// 设置扩容阈值,初始值为len * 2 / 3
setThreshold(INITIAL_CAPACITY);
}
private void setThreshold(int len) {
threshold = len * 2 / 3;
}
这里我们可以看到,ThreadLocal的set方法实际上是获取了Thread本地的ThreadLocalMap,然后对这个map进行操作。这里如果判断Thread本地的ThreadLocalMap还没有进行初始化,那么就将进行初始化。
get()
public T get() {
// 获取当前线程
Thread t = Thread.currentThread();
// 得到当前线程的ThreadLocalMap,底层由哈希表实现
ThreadLocalMap map = getMap(t);
// (1)如果map已经初始化,则进行读操作
if (map != null) {
// 获取键值对对象
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
// (2)如果map没有进行初始化则调用setInitialValue方法初始化map并返回null作为结果
return setInitialValue();
}
private T setInitialValue() {
// 获取初始值null
T value = initialValue();
// 获取当前线程的ThreadLocalMap
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
// 判断map是否已经初始化,未初始化则进行初始化
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}
3.ThreadLocalMap
前面已经提到了,ThreadLocalMap是一个重写的HashMap,目的就是为了将Key设置为弱引用,方便进行垃圾回收。我们可以看一下ThreadLocalMap的内部类Entry,Entry继承了WeakReference
,简单来说就是Entry弱引用于ThreadLocal对象。
static class Entry extends WeakReference<ThreadLocal<?>> {
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
ThreadLocalMap在进行set/map操作以后会触发清除操作,ThreadLocalMap中的清除操作具体可以分为启发式清除和探测式清除,这个我们后面也会细讲。
3.1 核心成员变量
// 初始化当前map内部散列表数组的初始长度 16
private static final int INITIAL_CAPACITY = 16;
// threadLocalMap 内部散列表数组的引用,数组的长度必须是2的次方数
private Entry[] table;
//当前散列表数组占用情况,存放多少个Entry。
private int size = 0;
/**
* 扩容触发阈值,初始值为: len * 2/3
* 触发后调用 rehash() 方法。
* rehash() 方法先做一次全量检查全局过期数据,把散列表中所有过期的entry移除。
* 如果移除之后,当前散列表中的entry个数仍然达到(threshold - threshold/4),
* 即,当前threshold阈值的3/4就进行扩容。
*/
private int threshold;
3.2 nextIndex()与prevIndex()
在接下来的源码分析中,你会看到很多个循环,每个循环中的下标偏移基本上都是通过这两个方法完成的。提前了解这两个方法有助于我们更好的理解代码。
// 获得下一个数组下标,并且保证下标不会越界(环绕)
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}
// 获得上一个数组下标,并且保证下标不会越界(环绕)
private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1);
}
3.3 探测式清除与启发式清除
探测式清除:expungeStaleEntry()
探测式清除将遍历散列数组,从开始位置向后探测清理过期数据。
- 沿途遍历中如果遇到过期数据(key为null的Entry),则将其清空也就是将该槽位置为null
- 沿途遍历中如果遇到未过期数据则进行rehash,将其放到其应该存放的位置,如果遇到哈希冲突则将其放到最近的一个空槽中
这里我们要尤其注意,ThreadLocalMap采用开放地址法解决哈希冲突,删除的过期元素是多个冲突元素重的一个,删除以后需要将冲突的元素向前挪动,这么做的目的是为了寻找元素的时候,避免遇到null就停止寻找了,前面key=null的Entry已经被置为null,如果不移动的话后面因为开放地址法偏移的元素就无法被访问。
你可以结合这些图来理解下探测式清除。
第一步,清除掉沿途中过期的Entry
第二步,将偏离正确槽的Entry进行rehash,让它离正确槽更近一些
最后直到遇到为null的Entry停止探测式清除。
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;
// 将staleSlot位置的过期Entry置为null
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;
Entry e;
int i;
// 以staleSlot为起始下标,向后扫描,直到遇到为null的Entry停止扫描
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
// (1)如果遇到过期Entry,将其置为null
if (k == null) {
e.value = null;
tab[i] = null;
size--;
// (2)如果未过期的Entry,将其进行哈希计算,如果是因为开放地址法向后移动的Entry,则将其重新放回对应的位置
} else {
// ThreadLocalMap采用开放地址法解决哈希冲突,删除的过期元素是多个冲突元素重的一个,删除以后需要将冲突的元素向前挪动,
// 这么做的目的是通过开放地址法寻找元素的时候,避免遇到null就停止寻找了,前面key=null的Entry已经被置为null,如果不移动的话后面的元素就无法被访问
int h = k.threadLocalHashCode & (len - 1);
// 不相等说明hash是有冲突的
if (h != i) {
tab[i] = null;
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}
启发式清除:cleanSomeSlots()
探测式清理是以当前Entry 往后清理,遇到值为null则结束清理,属于线性探测清理。
而启发式清除则是非线性的扫描部分槽位,检测是否有过期数据。实际上他还是调用的探测式清除,只不过针对于探测式清除,这个起始下标并不是连续的。
private boolean cleanSomeSlots(int i, int n) {
// 标记是否有过期Entry被清除
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len);
Entry e = tab[i];
// 有过期数据,则开始进行探测式清除,并且将清除标识removed置为true
if (e != null && e.get() == null) {
n = len;
removed = true;
i = expungeStaleEntry(i);
}
} while ( (n >>>= 1) != 0);
return removed;
}
3.4 set()
前面讲了那么多,现在才是本文的重点。ThreadLocal的set/get方法实际上都是调用了ThreadLocalMap的set/get方法。既然是哈希表,就会遇到哈希冲突,ThreadLocalMap通过开发地址法来解决哈希冲突。
整体的代码逻辑实际上分为四种情况:
- 通过哈希计算后的槽位对应的Entry为空
- 通过哈希计算后的槽位对应的Entry不为空,并且key相等
- 通过哈希计算后的槽位对应的Entry不为空,但是该Entry的key为null
- 通过哈希计算后的槽位对应的Entry不为空,并且该Entry的key不相等,通过线性探测法一直向后扫描一直没有找到所需要的Entry,直到遇到为null的Entry停下
了解了这四种情况,可以结合着这张图去理解这个逻辑:
这里假如有一个数据为27,经过哈希寻址,它将被存放在下标为4的槽中。
- 第一种情况和第四种情况将会直接创建新的Entry放到对应位置,并且会进行一次启发式清除,如果启发式清除没法扫描到过期元素,并且当前Entry数组容量超过了扩容阈值,则执行
rehash()
函数。 - 第二种情况直接进行值覆盖操作然后返回。
- 第三种情况则是调用replaceStaleEntry进行清除操作。
private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table;
int len = tab.length;
// 通过哈希计算获取对应槽位的下标
// 这里为什么是&(len - 1)而不是%len,这个其实是为了提高哈希寻址的效率,感兴趣的同学可以搜搜
int i = key.threadLocalHashCode & (len-1);
// 1.通过哈希计算后的槽位对应的Entry数据为空,走第四种情况对应的相同逻辑
for (Entry e = tab[i];
// 这个循环条件非常关键,一开始没有进入循环和进入循环后不满足条件跳出循环是两种情况
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
// 2.槽位对应的Entry不为空,并且key相等
if (k == key) {
e.value = value;
return;
}
// 3.槽位对应的Entry不为空,但是该Entry的key为null
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
// 4.对应的槽位的Entry不为空,并且该Entry的key不相等,通过线性探测法一直向后扫描,直到遇见Entry为null的槽位停下,创建新的Entry
tab[i] = new Entry(key, value);
int sz = ++size;
// 启发式清除过期数据,如果未清理到任何数据则返回false,并判断size是否已经超过了扩容阈值,达到以后则进行一次rehash()
// reHash实际上还是进行一次探测式清除,并且清除完以后,会Entry数组的size(不是length)判断是否需要进行扩容
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
看完了set方法大概的逻辑我们接着来看replaceStaleEntry()
方法的。
在看replaceStaleEntry()
方法之前,我们需要明确一个背景,这个方法是在set()
方法执行中遇到了过期数据的时候进入的。
replaceStaleEntry()
方法我们需要注意两个指针:staleSlot
和slotToExpunge
。
staleSlot
指向本方法进入时的过期数据的位置,从始至终不会改变slotToExpunge
则是指明清楚操作的初始下标,这个指针会更新。
整个replaceStaleEntry()
方法实际上由两个循环构成:
- 循环一:从当前的
staleSlot
向前遍历,查找其他过期的数据(key为null的Entry),不断地更新过期数据起始扫描下标slotToExpunge
,直到遇到为null的Entry停下 - 循环二:从当前
staleSlot
向后查找key值相等的Entry元素。如果找到,则将这个元素放到staleSlot
指向的位置
当第二段循环遇到为null的Entry以后,将会跳出循环,意味着没有找到对应的数据,那么将在staleSlot
的位置创建一个新的Entry。
在循环二中如果覆盖了值,或者在循环二结束后创建了新的值以后,都会开启一轮扫描清除操作。
具体的代码流程可以结合图片来理解。
如果在循环二中遇到了为null的Entry但是仍然没有找到对应的数据,那么就是下面这种情况:
private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;
// 循环一
// 从当前的staleSlot向前遍历,查找其他过期的数据(key为null的Entry),不断地更新过期数据起始扫描下标slotToExpunge,直到遇到为null的Entry停下
int slotToExpunge = staleSlot;
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len)) {
if (e.get() == null)
slotToExpunge = i;
}
// 循环二
// 从当前staleSlot向后查找key值相等的Entry元素
// (1)如果找到则更新并将其替换到staleSlot的位置,
// (2)如果直到遇到一个为null的Entry,还没有找到对应的key相等的Entry,则创建一个新的Entry放到staleSlot的位置
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
// 3.1 如果找到key相等的Entry元素,则覆盖值,并且将该Entry替换到staleSlot的位置
if (k == key) {
e.value = value;
tab[i] = tab[staleSlot];
tab[staleSlot] = e;
// 经历第一个循环以后,如果slotToExpunge == staleSlot仍然成立,那么意味着当前下标staleSlot前后都没有过期的Entry,直接开始清除操作
// 在这里就直接结束了该方法的执行
if (slotToExpunge == staleSlot)
slotToExpunge = i;
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}
// key == null并且slotToExpunge == staleSlot,意味着循环一从staleSlot开始到staleSlot结束并没有找到这个过期数据,这个时候将slotToExpunge指向当前下标
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}
// 3.2 如果上面的循环直到遇到了为null的Entry,仍然没有找到对应的key相等的Entry,意味着以该ThreadLocal为key的Entry不存在于数组中,则新建一个放进去
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);
// 如果有其他已经过期的对象,则清理此过期对象
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
3.5 getEntry()
getEntry()
方法的流程还是比较简单的
- 通过哈希计算获取正确槽位的下标
- 如果有数据并且未过期,则直接返回
- 如果没有数据,则调用
getEntryAfterMiss()
方法进行线性查找
- 在调用
getEntryAfterMiss()
方法进行线性查找的过程中- 如果遇到key相等的则直接返回
- 如果遇到过期的Entry则进行一次探测式扫描
- 如果遇到未过期的Entry且key不相等则换下一个下标
直到最后遇到为null的Entry跳出循环,意味着没有对应的数据,返回null结束。
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
// (1)获取到结果直接返回
if (e != null && e.get() == key)
return e;
// (2)没有则通getEntryAfterMiss查找与当前Entry相邻的其他Entry
else
return getEntryAfterMiss(key, i, e);
}
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
while (e != null) {
ThreadLocal<?> k = e.get();
// (1)如果遇到key相等的则直接返回
if (k == key)
return e;
// (2)如果遇到过期的Entry则进行一次探测式扫描
if (k == null)
expungeStaleEntry(i);
// (3)如果遇到未过期的Entry且key不相等则换下一个下标
else
i = nextIndex(i, len);
e = tab[i];
}
// 直到遇到为null的Entry则意味着没找到,返回null结束
return null;
}
3.6 rehash()
在set()
方法执行的最后一步,进行启发式清除过期数据,如果未清理到任何数据则返回false,并判断size是否已经超过了扩容阈值,如果达到则执行rehash()
。rehash()
实际上还是进行一次探测式清除,并且清除完以后,会以Entry数组的size(不是length)判断是否需要进行扩容。
private void rehash() {
expungeStaleEntries();
if (size >= threshold - threshold / 4)
resize();
}
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
if (e != null && e.get() == null)
expungeStaleEntry(j);
}
}
3.7 resize()
扩容的逻辑也非常简单,将原数组扩大两倍,然后将每一个元素重新进行哈希计算将其放到新的Entry数组中。扩大两倍的目的是为了保证Entry数组的大小是2的倍数,这样才能保证i & (len - 1)
和i & len
的结果一样,前面也说过了&
的目的就是为了提高寻址效率。
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;
for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null; // Help the GC
} else {
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}
setThreshold(newLen);
size = count;
table = newTab;
}
4.更多思考
关于ThreadLocal可能导致的内存泄漏
这是一个很经典的JUC八股问题,网上泛泛而谈弱引用导致内存泄漏,但是都没有细说,我在这里讲一下吧。
在ThreadLocal中,ThreadLocalMap中的Entry弱引用于作为Key的ThreadLocal对象。像我们直接创建一个ThreadLocal类,就是ThreadLocal threadLocal = new ThreadLocal<String>();
。下面threadLocal指代变量引用,ThreadLocal指代存放在堆内存中的对象。
这里threadLocal强引用于一个ThreadLocal对象。当垃圾回收时,如果这个强引用还在那么这个ThreadLocal对象就不会被回收。
但是如果threadLocal不再强引用于这个ThreadLocal对象,也就是有关于这个threadLocal变量的代码执行完毕以后,这个时候ThreadLocal对象就只存在Entry对其的弱引用,如果发生垃圾回收那么ThreadLocal对象就会被回收。
但是这个时候Entry中的value并没有被回收,这样就可能导致内存泄漏。
之所以这样设计本身也是为了方便ThreadLocal对象的被回收行为,如果是强引用的话,那么当threadLocal变量不再使用,其原本指向的ThreadLocal对象将迟迟无法被回收。
所以每次在使用完threadlocal时要调用一下remove方法,它会自动把entry移除。
关于InheritableThreadLocal
前面我们提到了InheritableThreadLocal
是为了解决父线程的变量副本无法传递给子线程的问题背景产生的。
我们在Thread的init()
方法中也看到了,子线程创建的时候会复制父线程不为null的InheritableThreadLocal
。
但InheritableThreadLocal
仍然有缺陷,一般我们做异步化处理都是使用的线程池,而InheritableThreadLocal
是在new Thread
中的init()
方法给赋值的,而线程池是线程复用的逻辑,所以这里会存在问题。
参考文章:
- https://juejin.cn/post/6844904151567040519