DJLServing自定义模型中自定义Translator注意事项需要仔细读一下DJLServing源码中的ServingTranslatorFactory类,,一开始不了解以为DJLServing选择Translator像玄学,后来看了像迷宫一样ServingTranslatorFactory类大致明白了,以下是源码注释版,还有一个整理的流程图。
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
import ai.djl.Application;
import ai.djl.Model;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.modality.cv.translator.ImageServingTranslator;
import ai.djl.translate.*;
import ai.djl.util.ClassLoaderUtils;
import ai.djl.util.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.lang.reflect.Constructor;
import java.lang.reflect.Type;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
class ServingTranslatorFactory implements TranslatorFactory {
//日志打印
private static final Logger logger = LoggerFactory.getLogger(ServingTranslatorFactory.class);
//返回只有一个固定元素的SET,约束模型的输入,输出类型
@Override
public Set<Pair<Type, Type>> getSupportedTypes() {
return Collections.singleton(new Pair<>(Input.class, Output.class));
}
//工厂实例化方法
@Override
@SuppressWarnings("unchecked")
public <I, O> Translator<I, O> newInstance(
Class<I> input, Class<O> output, Model model, Map<String, ?> arguments)
throws TranslateException {
//如果输出和输出不在支持的范围内直接抛出异常
if (!isSupported(input, output)) {
throw new IllegalArgumentException("Unsupported input/output types.");
}
//获取model的路径
Path modelDir = model.getModelPath();
//获取serving.properties里面的translatorFactory参数
String factoryClass = ArgumentsUtil.stringValue(arguments, "translatorFactory");
//如果translatorFactory参数不为null且长度不为0
if (factoryClass != null && !factoryClass.isEmpty()) {
//直接加载工厂类
TranslatorFactory factory = loadTranslatorFactory(factoryClass);
//如果工厂类加载成功并且工厂类支持要去的输入输出
if (factory != null && factory.isSupported(input, output)) {
//打印日志
logger.info("Using TranslatorFactory: {}", factory.getClass().getName());
//将工厂类实例化返回
return factory.newInstance(input, output, model, arguments);
}
}
//如果上面没有匹配上
//获取serving.properties里面的translator参数
String className = (String) arguments.get("translator");
//获取model目录下的libs目录
Path libPath = modelDir.resolve("libs");
//如果这个libs目录不存在
if (!Files.isDirectory(libPath)) {
//那就找lib目录
libPath = modelDir.resolve("lib");
//如果lib目录也没有那就走loadDefaultTranslator(arguments)这个方法,加载默认的Translator
if (!Files.isDirectory(libPath) && className == null) {
return (Translator<I, O>) loadDefaultTranslator(arguments);
}
}
//如果model目录下的libs目录存在那就加载class
ServingTranslator translator = findTranslator(libPath, className);
//如果加载上了
if (translator != null) {
//设置translator的参数
translator.setArguments(arguments);
//打印日志
logger.info("Using translator: {}", translator.getClass().getName());
//直接返回translator
return (Translator<I, O>) translator;
} else if (className != null) {
//如果加载失败抛出异常
throw new TranslateException("Failed to load translator: " + className);
}
//实在是找不到就走loadDefaultTranslator(arguments)这个方法,加载默认的Translator
return (Translator<I, O>) loadDefaultTranslator(arguments);
}
private ServingTranslator findTranslator(Path path, String className) {
//找目录里面的classes目录
Path classesDir = path.resolve("classes");
//把java编译成classes
ClassLoaderUtils.compileJavaClass(classesDir);
//返回出去Translator,该类必须是ServingTranslator的实现类,因为会强制转换成ServingTranslator在子类
return ClassLoaderUtils.findImplementation(path, ServingTranslator.class, className);
}
private TranslatorFactory loadTranslatorFactory(String className) {
try {
//通过类名加载该类
Class<?> clazz = Class.forName(className);
//将该类强制转换成TranslatorFactory的子类
Class<? extends TranslatorFactory> subclass = clazz.asSubclass(TranslatorFactory.class);
//加载该类的构造方法
Constructor<? extends TranslatorFactory> constructor = subclass.getConstructor();
//构造该类返回实例
return constructor.newInstance();
} catch (Throwable e) {
//捕获异常
logger.trace("Not able to load TranslatorFactory: " + className, e);
}
return null;
}
private Translator<Input, Output> loadDefaultTranslator(Map<String, ?> arguments) {
//获取serving.properties里面的application参数
String appName = ArgumentsUtil.stringValue(arguments, "application");
//如果不为空
if (appName != null) {
Application application = Application.of(appName);
//如果是cv/image_classification
if (application == Application.CV.IMAGE_CLASSIFICATION) {
//那就加载ImageClassificationTranslator这个玩意
return getImageClassificationTranslator(arguments);
}
}
//否则的化就加载NoopServingTranslatorFactory这个玩意
NoopServingTranslatorFactory factory = new NoopServingTranslatorFactory();
//最后返回的是NoopServingTranslator这个玩意
return factory.newInstance(Input.class, Output.class, null, arguments);
}
private Translator<Input, Output> getImageClassificationTranslator(Map<String, ?> arguments) {
//返回ImageServingTranslator的实例
return new ImageServingTranslator(ImageClassificationTranslator.builder(arguments).build());
}
}