综合记录一下关于ClassLoader和Spring Bean的动态加载卸载功能
目录
- 一、需要说明
- 二、总体设计
- 三、具体设计
- 3.1 加载卸载Bean工具类
- 3.2 创建卸载方法
- 3.3 创建加载方法
- 3.4 创建获取具体服务类方法
- 四、总结
一、需要说明
- 有一个公共的发送通知的接口,这个接口需要做成单独jar,可以通过maven包引入该接口
public interface NotifyService {
void notify(MyMessage myMessage);
}
-
会有多个实现类
com.pv1.notify.MailNotifyService 邮件通知
com.pv1.notify.WxNotifyService 微信通知
com.pv1.notify.DingNotifyService 钉钉通知 -
每一个实现类为一个插件, 插件打成一个jar包上传OSS文件服务中
-
从远程文件服务中下载到本地并且注入到Spring容器中
-
支持不同版本的jar加载和卸载功能
二、总体设计
-
设计接口和公共参数对象,并且打成jar包 发布到maven仓库中
public interface NotifyService { Map<String, NotifyService > NOTIFY_SERVICE_MAP = new ConcurrentHashMap<>(32); void notify(MyMessage myMessage); }
-
创建插件信息表 jarInfo,主要是用于存放插件的注册信息,核心字段包含如下:
字段 | 说明 |
---|---|
serverFlag | 服务标识 |
jarVersion | 件版本号 |
jarUrl | 插件下载地址 |
jarServerName | 插件注册Spring服务名 |
jarClassName | ClassLoader 加载时传入的类名称 |
isController | 是否web接口服务 默认0, 1是 0否 |
isEnable | 状态 是否启用, 启用=已加载 默认0, 1是 0否 |
-
各个实现类工程从Maven仓库中引入接口包,并且实现自身业务逻辑
-
采用策略模式 创建一个Map存放各个实现类 供应用系统调用
Map<String, NotifyService > NOTIFY_SERVICE_MAP = new ConcurrentHashMap<>(32);
这里的key是 各个实现类的标识 serverFlag 字段 ,例如邮件服务的话 key为 “MAIL”
NOTIFY_SERVICE_MAP 可以作为 接口 NotifyService 的成员常量 -
创建加载方法 void loadProtocol(JarInfo jarInfo);
-
创建卸载方法 void unloadProtocol(JarInfo jarInfo);
-
创建获取具体服务类方法 NotifyService getNotifyService(String serverFlag );
总的设置原则:
表 jarInfo 里面需保证: 相同 serverFlag 的多行记录中只能有一条处于加载状态中。
三、具体设计
3.1 加载卸载Bean工具类
这里使用了hutools工具,总体上方法差不多,如有你要注册Controller 则先注册为SpringBean然后在调用注册Controller方法。
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.bean.copier.CopyOptions;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.extra.spring.SpringUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.servlet.mvc.method.RequestMappingInfo;
import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping;
import java.lang.reflect.Method;
import java.util.List;
@Slf4j
public class ExtendBeanUtil {
public static void registerBeanDefinition(String beanName, Class<?> targetClass) {
registerBeanDefinition(beanName, targetClass, null, null, null);
}
public static void registerBeanDefinition(String beanName, Class<?> targetClass, List<Object> constructorArgs) {
registerBeanDefinition(beanName, targetClass, constructorArgs, null, null);
}
public static void registerBeanDefinition(
String beanName,
Class<?> targetClass,
String initMethodName) {
registerBeanDefinition(beanName, targetClass, null, initMethodName, null);
}
public static void registerBeanDefinition(
String beanName,
Class<?> targetClass,
String initMethodName,
String destoryMethodName) {
registerBeanDefinition(beanName, targetClass, null, initMethodName, destoryMethodName);
}
public static void registerBeanDefinition(
String beanName,
Class<?> targetClass,
List<Object> constructorArgs,
String initMethodName,
String destoryMethodName) {
ApplicationContext applicationContext = SpringUtil.getApplicationContext();
//获取BeanFactory
DefaultListableBeanFactory defaultListableBeanFactory =
(DefaultListableBeanFactory) applicationContext.getAutowireCapableBeanFactory();
//创建bean信息.
BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder.genericBeanDefinition(targetClass);
// 如果有构造函数参数
if (CollectionUtil.isNotEmpty(constructorArgs)) {
for (Object arg : constructorArgs) {
beanDefinitionBuilder.addConstructorArgValue(arg);
}
}
// 设置 init方法
if (StrUtil.isNotBlank(initMethodName)) {
beanDefinitionBuilder.setInitMethodName(initMethodName);
}
// 设置 destory方法
if (StrUtil.isNotBlank(destoryMethodName)) {
beanDefinitionBuilder.setDestroyMethodName(destoryMethodName);
}
//动态注册bean.
defaultListableBeanFactory.registerBeanDefinition(beanName, beanDefinitionBuilder.getBeanDefinition());
}
public static void unRegisterBeanDefinition(String beanName) {
ApplicationContext applicationContext = SpringUtil.getApplicationContext();
if (!applicationContext.containsBean(beanName)) {
return;
}
//获取BeanFactory
DefaultListableBeanFactory defaultListableBeanFactory =
(DefaultListableBeanFactory) applicationContext.getAutowireCapableBeanFactory();
defaultListableBeanFactory.removeBeanDefinition(beanName);
}
public static void registerController(String controllerBeanName)
throws Exception {
final RequestMappingHandlerMapping requestMappingHandlerMapping =
SpringUtil.getBean(RequestMappingHandlerMapping.class);
if (requestMappingHandlerMapping != null) {
ApplicationContext applicationContext = SpringUtil.getApplicationContext();
if (!applicationContext.containsBean(controllerBeanName)) {
log.warn("注册Controller {} 不成功,因为在BeanFactory未找到对应的Bean信息", controllerBeanName);
return;
}
//注册Controller
Method method = requestMappingHandlerMapping.getClass().getSuperclass().getSuperclass().
getDeclaredMethod("detectHandlerMethods", Object.class);
//将private改为可使用
method.setAccessible(true);
method.invoke(requestMappingHandlerMapping, controllerBeanName);
}
}
public static void unregisterController(String controllerBeanName) {
final RequestMappingHandlerMapping requestMappingHandlerMapping
= SpringUtil.getBean("requestMappingHandlerMapping");
if (requestMappingHandlerMapping != null) {
ApplicationContext applicationContext = SpringUtil.getApplicationContext();
if (!applicationContext.containsBean(controllerBeanName)) {
return;
}
Object controller = SpringUtil.getBean(controllerBeanName);
if (controller == null) {
log.warn("卸载Controller {} 取消执行,因为在BeanFactory未找到对应的Bean信息", controllerBeanName);
return;
}
final Class<?> targetClass = controller.getClass();
ReflectionUtils.doWithMethods(targetClass, method -> {
Method specificMethod = ClassUtils.getMostSpecificMethod(method, targetClass);
try {
Method createMappingMethod = RequestMappingHandlerMapping.class.
getDeclaredMethod("getMappingForMethod", Method.class, Class.class);
createMappingMethod.setAccessible(true);
RequestMappingInfo requestMappingInfo = (RequestMappingInfo)
createMappingMethod.invoke(requestMappingHandlerMapping, specificMethod, targetClass);
if (requestMappingInfo != null) {
requestMappingHandlerMapping.unregisterMapping(requestMappingInfo);
}
} catch (Exception e) {
e.printStackTrace();
}
}, ReflectionUtils.USER_DECLARED_METHODS);
}
}
}
3.2 创建卸载方法
public void unloadProtocol(JarInfo jarInfo) {
printLog(jarInfo, "开始卸载服务", "...");
String jarServerName = jarInfo.getJarServerName();
String isController = jarInfo.getIsController();
if (YesOrNoEnum.YES.getCode().equals(isController)) {
ExtendBeanUtil.unregisterController(jarServerName);
}
ExtendBeanUtil.unRegisterBeanDefinition(jarServerName);
// 从全局策略模式 NOTIFY_SERVICE_MAP 剔除
NotifyService.NOTIFY_SERVICE_MAP.remove(jarInfo.getServerFlag());
printLog(jarInfo, "完成卸载服务", "");
}
这里的 printLog 就是简单的日志打印方法根据自己业务自行实现具体逻辑, YesOrNoEnum是个简单枚举 “1” 和 “0” 。
说明:
一般ClassLoader 已经加载的Class 建议不要想从虚拟机中卸载,这样可能导致很多异常情况。我们的web服务为Spring容器服务,我们直接从容器中卸载该服务即可。
建议:
不同版本的实现类文件放在不同包下面,例如:
com.pv1.notify.MailNotifyService
com.pv2.notify.MailNotifyService
com.pv3.notify.MailNotifyService
如果放同一包下面, 相同ClassLoader 不会重复加载相同类名 (包路径+文件名称一致) 的类文件。
3.3 创建加载方法
加载方法稍微复杂一些,可能需要从远程下载jar文件,
另外需要保证是同一个ClassLoader进行类的加载,且这个ClassLoader需实现双亲委派机制。
这里用了 hutools里面的工具类。
classLoader相关知识点参考不错的知乎文章: https://zhuanlan.zhihu.com/p/51374915
public void loadProtocol(JarInfo protocol) {
String jarClassName = protocol.getJarClassName();
String jarServerName = protocol.getJarServerName();
String isController = protocol.getIsController();
printLog(protocol, "开始加载服务", "...");
try {
// 获得一个ClassLoader
JarClassLoader jarClassLoader = getJarClassLoader(protocol);
// 先卸载
unloadProtocol(protocol);
// 加载目标类
Class<?> targetClass = jarClassLoader.loadClass(jarClassName);
printLog(protocol, "classLoader完成", "...");
// 注入到Spring容器中
ExtendBeanUtil.registerBeanDefinition(jarServerName, targetClass);
// 是否controller层接口
if (YesOrNoEnum.YES.getCode().equals(isController)) {
ExtendBeanUtil.registerController(jarServerName);
printLog(protocol, "通知类服务加载controller层接口", "...");
}
// 设置到相应的业务 MAP中,构建策略模式
afterLoadProtocol(protocol);
printLog(protocol, "完成加载服务", "");
} catch (Exception e) {
printLog(protocol, "加载服务失败", e.getMessage());
}
}
获取对应的 JarClassLoader
// 全局变量
protected final Map<String, JarClassLoader> jarClassLoaderMap = new ConcurrentHashMap<>(16);
protected JarClassLoader getJarClassLoader(JarInfo protocol) throws IOException {
String jarUrl = protocol.getJarUrl();
String baseLoaderPath = "notifyJarFiles";
JarClassLoader jarClassLoader = jarClassLoaderMap.get(jarUrl);
if (jarClassLoader == null) {
String[] jarUrlItems = jarUrl.split("/");
// 创建本地临时文件路径
File file = CreateTmpFileUtil.createTmpFile(baseLoaderPath, jarUrlItems[jarUrlItems.length - 1]);
// 本地文件不存在从远程下载
if (!file.exists()) {
OutputStream outputStream = Files.newOutputStream(file.toPath());
// ossService 为文件服务下载工具类
ossService.downloadFileToOutputStream(jarUrl, outputStream);
printLog(protocol, "从远程服务器下载jar文件完成", "...");
}
printLog(protocol, "jar文件地址:" + file.getAbsolutePath(), "...");
jarClassLoader = ClassLoaderUtil.getJarClassLoader(file);
// 保存ClassLoader 下次再用
jarClassLoaderMap.put(jarUrl, jarClassLoader);
}
return jarClassLoader;
}
设置到相应的业务 MAP中,构建策略模式
protected void afterLoadProtocol(JarInfo protocol) {
String jarType = protocol.getJarType();
String jarServerName = protocol.getJarServerName();
NotifyService notifyService = SpringUtil.getBean(jarServerName, NotifyService.class);
// 添加到 全局策略模式 NOTIFY_SERVICE_MAP 中
NotifyService.NOTIFY_SERVICE_MAP.put(jarInfo.getServerFlag(), notifyService);
}
3.4 创建获取具体服务类方法
public NotifyService getNotifyService(String serverFlag ) {
// 从最新缓存或数据库中加载jar信息,根据服务标识 serverFlag
JarInfo jarInfo = getEnableJarInfo(serverFlag);
String jarClassName = jarInfo.getJarClassName();
NotifyService service = NOTIFY_SERVICE_MAP.get(serverFlag);
if (service != null) {
// 检验是否和当前的协议转换层配置信息一致
String existsClassName = service.getClass().getName();
if (!existsClassName.equals(jarClassName)) {
// 表中最新的jar信息和目前缓存的不一致
// 重新加载
loadProtocol(jarInfo);
}
} else {
// 重新加载
loadProtocol(jarInfo);
}
// 重新取一次
return NotifyService.NOTIFY_SERVICE_MAP.get(serverFlag );
}
重点说明:
getEnableJarInfo 方法是根据 serverFlag 取表jarInfo 中 enable=1的 唯一一条数据。
如果担心 getEnableJarInfo 每次都要重表里面取导致应用性能有问题,则建议先从缓存中取取不到从数据库中取,但确保缓存中和数据库中数据一致。
验证获取到的服务类的版本是否和表中一致 :
if (!existsClassName.equals(jarClassName)) 这行代码的理由是 一开始规定了 不同版本的Class文件对应的包路径也不一样,
因此jarClassName如果不相同则表示当前Spring容器中的服务类需要卸载然后重新加载。
四、总结
以上设计代码主要示意为主,真正用于生产环境还需进一步优化,总体上各功能都已经实现。
Spring应用启动时候需要根据表中的配置信息进行初始化操作。
建议实现 CommandLineRunner 接口,在 public void run(String… args) 方法中进行初始化:
查询所有启用中的jarInfo记录,然后调用 loadProtocol 方法