背景:为了获取相关字段方便,项目里使用了TransmittableThreadLocal上下文,在异步逻辑中get值时发现并非当前请求的值,且是偶发状况(并发问题)。
发现:TransmittableThreadLocal是阿里开源的可以实现父子线程值传递的工具,其子线程必须使用TtlRunnable\TtlCallable修饰或者线程池使用TtlExecutors修饰(防止数据“污染”),如果没有使用装饰后的线程池,那么使用TransmittableThreadLocal上下文,就有可能出现线程不安全的问题。
参考代码:
封装的上下文,成员变量RequestHeader
package org.example.ttl.threadLocal;
import com.alibaba.ttl.TransmittableThreadLocal;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
import org.apache.commons.lang3.ObjectUtils;
/**
* description:
* author: JohnsonLiu
* create at: 2021/12/24 23:19
*/
@Data
@AllArgsConstructor
@NoArgsConstructor
public class RequestContext {
private static final ThreadLocal<RequestContext> transmittableThreadLocal = new TransmittableThreadLocal();
private static final RequestContext INSTANCE = new RequestContext();
private RequestHeader requestHeader;
public static void create(RequestHeader requestHeader) {
transmittableThreadLocal.set(new RequestContext(requestHeader));
}
public static RequestContext current() {
return ObjectUtils.defaultIfNull(transmittableThreadLocal.get(), INSTANCE);
}
public static RequestHeader get() {
return current().getRequestHeader();
}
public static void remove() {
transmittableThreadLocal.set(null);
}
@Data
@AllArgsConstructor
@NoArgsConstructor
@ToString
static class RequestHeader {
private String requestUrl;
private String requestType;
}
}
获取上下文内容的case:
package org.example.ttl.threadLocal;
import com.alibaba.ttl.threadpool.TtlExecutors;
import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
/**
* description: TransmittableThreadLocal正确使用
* author: JohnsonLiu
* create at: 2021/12/24 22:24
* 验证结论:
* 1.线程池必须使用TtlExecutors修饰,或者Runnable\Callable必须使用TtlRunnable\TtlCallable修饰
* ---->原因:子线程复用,子线程拥有的上下文内容会对下次使用造成“污染”,而修饰后的子线程在执行run方法后会进行“回放”,防止污染
*/
public class TransmittableThreadLocalCase2 {
// 为达到线程100%复用便于测试,线程池核心数1
private static final Executor TTL_EXECUTOR = TtlExecutors.getTtlExecutor(new ThreadPoolExecutor(1, 1, 1000, TimeUnit.MICROSECONDS, new LinkedBlockingQueue<>(1000)));
// 如果使用一般的线程池或者Runnable\Callable时,会存在线程“污染”,比如线程池中线程会复用,复用的线程会“污染”该线程执行下一次任务
private static final Executor EXECUTOR = new ThreadPoolExecutor(1, 1, 1000, TimeUnit.MICROSECONDS, new LinkedBlockingQueue<>(1000));
public static void main(String[] args) {
RequestContext.create(new RequestContext.RequestHeader("url", "get"));
System.out.println(Thread.currentThread().getName() + " 子线程(rm之前 同步):" + RequestContext.get());
// 模拟另一个线程修改上下文内容
EXECUTOR.execute(() -> {
RequestContext.create(new RequestContext.RequestHeader("url", "put"));
});
// 保证上面子线程修改成功
try {
Thread.sleep(2000);
} catch (InterruptedException e) {
e.printStackTrace();
}
// 异步获取上下文内容
TTL_EXECUTOR.execute(() -> {
try {
Thread.sleep(2000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(Thread.currentThread().getName() + " 子线程(rm之前 异步):" + RequestContext.get());
});
// 主线程修改上下文内容
RequestContext.create(new RequestContext.RequestHeader("url", "post"));
System.out.println(Thread.currentThread().getName() + " 子线程(rm之前 同步<reCreate>):" + RequestContext.get());
// 主线程remove
RequestContext.remove();
// 子线程获取remove后的上下文内容
TTL_EXECUTOR.execute(() -> {
try {
Thread.sleep(3000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(Thread.currentThread().getName() + " 子线程(rm之后 异步):" + RequestContext.get());
});
}
}
使用一般线程池结果:
使用修饰后的线程池结果:
这种问题的解决办法:
如果大家跟我一样存在这样的使用,那么也会低概率存在这样的问题,正确的使用方式是:
子线程必须使用TtlRunnable\TtlCallable修饰或者线程池使用TtlExecutors修饰,这一点很容易被遗漏,比如上下文和异步逻辑不是同一个人开发的,那么异步逻辑的开发者就很可能直接在异步逻辑中使用上下文,而忽略装饰线程池,造成线程复用时的“数据污染”。
另外还有一种不同于上面的上下文用法,同样使用不当也会存在线程安全问题:
上代码样例
package org.example.ttl.threadLocal;
import com.alibaba.ttl.TransmittableThreadLocal;
import java.util.LinkedHashMap;
import java.util.Map;
/**
* description: TransmittableThreadLocal正确使用
* author: JohnsonLiu
* create at: 2021/12/24 23:19
*/
public class ServiceContext {
private static final ThreadLocal<Map<Integer, Integer>> transmittableThreadLocal = new TransmittableThreadLocal() {
/**
* 如果使用的是TtlExecutors装饰的线程池或者TtlRunnable、TtlCallable装饰的任务
* 重写copy方法且重新赋值给新的LinkedHashMap,不然会导致父子线程都是持有同一个引用,只要有修改取值都会变化。引用值线程不安全
* parentValue是父线程执行子任务那个时刻的快照值,后续父线程再次set值也不会影响子线程get,因为已经不是同一个引用
* @param parentValue
* @return
*/
@Override
public Object copy(Object parentValue) {
if (parentValue instanceof Map) {
System.out.println("copy");
return new LinkedHashMap<Integer, Integer>((Map) parentValue);
}
return null;
}
/**
* 如果使用普通线程池执行异步任务,重写childValue即可实现子线程获取的是父线程执行任务那个时刻的快照值,重新赋值给新的LinkedHashMap,父线程修改不会影响子线程(非共享)
* 但是如果使用的是TtlExecutors装饰的线程池或者TtlRunnable、TtlCallable装饰的任务,此时就会变成引用共享,必须得重写copy方法才能实现非共享
* @param parentValue
* @return
*/
@Override
protected Object childValue(Object parentValue) {
if (parentValue instanceof Map) {
System.out.println("childValue");
return new LinkedHashMap<Integer, Integer>((Map) parentValue);
}
return null;
}
/**
* 初始化,每次get时都会进行初始化
* @return
*/
@Override
protected Object initialValue() {
System.out.println("initialValue");
return new LinkedHashMap<Integer, Integer>();
}
};
public static void set(Integer key, Integer value) {
transmittableThreadLocal.get().put(key, value);
}
public static Map<Integer, Integer> get() {
return transmittableThreadLocal.get();
}
public static void remove() {
transmittableThreadLocal.remove();
}
}
使用case:
package org.example.ttl.threadLocal;
import com.alibaba.ttl.threadpool.TtlExecutors;
import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
/**
* description: TransmittableThreadLocal正确使用
* author: JohnsonLiu
* create at: 2021/12/24 22:24
*/
public class TransmittableThreadLocalCase {
private static final Executor executor = TtlExecutors.getTtlExecutor(new ThreadPoolExecutor(1, 1, 1000, TimeUnit.MICROSECONDS, new LinkedBlockingQueue<>(1000)));
// private static final Executor executor = new ThreadPoolExecutor(1, 1, 1000, TimeUnit.MICROSECONDS, new LinkedBlockingQueue<>(1000));
static int i = 0;
public static void main(String[] args) {
ServiceContext.set(++i, i);
executor.execute(() -> {
try {
Thread.sleep(3000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(Thread.currentThread().getName() + " 子线程(rm之前):" + ServiceContext.get());
});
ServiceContext.set(++i, i);
ServiceContext.remove();
executor.execute(() -> {
try {
Thread.sleep(3000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(Thread.currentThread().getName() + " 子线程(rm之后):" + ServiceContext.get());
});
}
}