Milvus向量库和SeetaSDK工具类分享
- 1.Milvus向量库工具类
- 2.SeetaSDK工具类
1.Milvus向量库工具类
Milvus的Maven依赖:
<!-- Milvus Java SDK client. log4j-slf4j-impl is excluded, presumably so it does not add a
     second SLF4J binding next to the application's own logging backend (confirm). -->
<dependency>
<groupId>io.milvus</groupId>
<artifactId>milvus-sdk-java</artifactId>
<version>2.1.0</version>
<exclusions>
<exclusion>
<artifactId>log4j-slf4j-impl</artifactId>
<groupId>org.apache.logging.log4j</groupId>
</exclusion>
</exclusions>
</dependency>
向量库的配置类:
/**
 * Connection settings for the Milvus vector database, bound from the
 * {@code milvus-config.*} properties.
 *
 * <p>The fields are public because callers read them directly; Lombok's
 * {@code @Data} also generates the getters/setters used for property binding.
 * Host and port default to the Milvus SDK's own defaults so that a missing
 * property does not leave the client pointing at port 0.
 */
@Data
@Component
@ConfigurationProperties(MilvusConfiguration.PREFIX)
public class MilvusConfiguration {
    /** Property prefix under which these settings are bound. */
    public static final String PREFIX = "milvus-config";
    /** Milvus server host name; defaults to {@code localhost}. */
    public String host = "localhost";
    /** Milvus gRPC port; defaults to 19530, the standard Milvus port. */
    public int port = 19530;
    /** Name of the collection used to store and search image features. */
    public String collectionName;
}
工具类主类:
@Slf4j
@Component
public class MilvusUtil {
@Resource
private MilvusConfiguration milvusConfiguration;
private MilvusServiceClient milvusServiceClient;
@PostConstruct
private void connectToServer() {
milvusServiceClient = new MilvusServiceClient(
ConnectParam.newBuilder()
.withHost(milvusConfiguration.host)
.withPort(milvusConfiguration.port)
.build());
// 加载数据
LoadCollectionParam faceSearchNewLoad = LoadCollectionParam.newBuilder()
.withCollectionName(milvusConfiguration.collectionName).build();
R<RpcStatus> rpcStatusR = milvusServiceClient.loadCollection(faceSearchNewLoad);
log.info("Milvus LoadCollection [{}]", rpcStatusR.getStatus() == 0 ? "Successful!" : "Failed!");
}
}
主类里的数据入库方法:
/**
 * Inserts one image record into the configured collection.
 *
 * @param id      value stored in the {@code id} field
 * @param path    value stored in the {@code image_path} field
 * @param feature feature vector stored in the {@code image_feature} field
 * @return the Milvus status code of the insert ({@code 0} means success);
 *         a non-zero status is also logged together with the server message
 */
public int insertDataToMilvus(String id, String path, float[] feature) {
    // The SDK expects the vector as a List<Float>, so box the primitive array.
    List<Float> featureList = new ArrayList<>(feature.length);
    for (float v : feature) {
        featureList.add(v);
    }
    List<InsertParam.Field> fields = new ArrayList<>(3);
    fields.add(new InsertParam.Field("id", Collections.singletonList(id)));
    fields.add(new InsertParam.Field("image_path", Collections.singletonList(path)));
    fields.add(new InsertParam.Field("image_feature", Collections.singletonList(featureList)));
    InsertParam insertParam = InsertParam.newBuilder()
            .withCollectionName(milvusConfiguration.collectionName)
            .withFields(fields)
            .build();
    R<MutationResult> insert = milvusServiceClient.insert(insertParam);
    if (insert.getStatus() != R.Status.Success.getCode()) {
        // Previously a failed insert was only visible through the returned int.
        log.error("Milvus insert failed, id [{}] status [{}] message [{}]",
                id, insert.getStatus(), insert.getMessage());
    }
    return insert.getStatus();
}
主类里的数据查询方法:
- 这里的topK没有进行参数化。
/**
 * Searches the configured collection for the images most similar to {@code feature},
 * returning at most 10 hits. Equivalent to {@code searchImageByFeatureVector(feature, 10)}.
 *
 * @param feature query feature vector
 * @return hits in the order returned by Milvus, each with its score and image path
 * @throws IllegalStateException if the Milvus search call does not succeed
 */
public List<MilvusRes> searchImageByFeatureVector(float[] feature) {
    return searchImageByFeatureVector(feature, 10);
}

/**
 * Searches the configured collection (inner-product metric on {@code image_feature})
 * for the {@code topK} images most similar to {@code feature}.
 *
 * @param feature query feature vector
 * @param topK    maximum number of hits to return
 * @return hits in the order returned by Milvus, each with its score (rounded to 3
 *         decimals by the server) and image path; empty if nothing matched
 * @throws IllegalStateException if the Milvus search call does not succeed
 */
public List<MilvusRes> searchImageByFeatureVector(float[] feature, int topK) {
    List<Float> featureList = new ArrayList<>(feature.length);
    for (float v : feature) {
        featureList.add(v);
    }
    SearchParam faceSearch = SearchParam.newBuilder()
            .withCollectionName(milvusConfiguration.collectionName)
            .withMetricType(MetricType.IP)
            .withVectorFieldName("image_feature")
            .withVectors(Collections.singletonList(featureList))
            .withOutFields(Collections.singletonList("image_path"))
            .withRoundDecimal(3)
            .withTopK(topK)
            .build();
    long start = System.currentTimeMillis();
    R<SearchResults> respSearch = milvusServiceClient.search(faceSearch);
    log.info("MilvusServiceClient.search cost [{}]", System.currentTimeMillis() - start);
    // Without this check a failed call surfaced as an NPE on getData().
    if (respSearch.getStatus() != R.Status.Success.getCode()) {
        throw new IllegalStateException(
                "Milvus search failed: " + respSearch.getMessage(), respSearch.getException());
    }
    SearchResultData results = respSearch.getData().getResults();
    int scoresCount = results.getScoresCount();
    List<MilvusRes> milvusResList = new ArrayList<>(scoresCount);
    if (scoresCount == 0) {
        return milvusResList;
    }
    SearchResultsWrapper wrapperSearch = new SearchResultsWrapper(results);
    // A single query vector was sent, so every hit belongs to target index 0.
    // Fetch both lists once instead of on every loop iteration.
    List<SearchResultsWrapper.IDScore> idScores = wrapperSearch.getIDScore(0);
    List<?> imagePaths = wrapperSearch.getFieldData("image_path", 0);
    for (int i = 0; i < scoresCount; i++) {
        MilvusRes milvusRes = MilvusRes.builder()
                .score(idScores.get(i).getScore())
                .imagePath(imagePaths.get(i).toString())
                .build();
        milvusResList.add(milvusRes);
    }
    return milvusResList;
}
2.SeetaSDK工具类
SeetaSDK的Maven依赖:
<!-- SeetaFace SDK jar, built locally from the official source and referenced from
     ${project.basedir}/lib through system scope. -->
<dependency>
<groupId>com.seeta</groupId>
<artifactId>sdk</artifactId>
<version>1.2.1</version>
<scope>system</scope>
<systemPath>${project.basedir}/lib/seeta-sdk-platform-1.2.1.jar</systemPath>
</dependency>
<!-- Note: because the SDK is a system-scoped dependency, spring-boot-maven-plugin must be
     told (includeSystemScope=true) to include it in the repackaged application jar. -->
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
<configuration>
<includeSystemScope>true</includeSystemScope>
</configuration>
</plugin>
该 jar 包是用从官网下载的源码自行打包得到的:
工具类主类:
/**
 * Spring component that wraps the SeetaFace SDK proxies used for face detection,
 * landmarking, feature extraction and attribute prediction.
 *
 * <p>The proxies are populated by the class's {@code @PostConstruct} initialiser from
 * the model files described by {@link SeetaModelConfiguration}.
 */
@Slf4j
@Component
public class FaceUtil {
    @Resource
    private SeetaModelConfiguration seetaModelConfiguration;

    // Detection, landmark location and feature-vector extraction.
    private FaceDetectorProxy faceDetectorProxy;
    private FaceLandmarkerProxy faceLandmarkerProxy;
    private FaceRecognizerProxy faceRecognizerProxy;

    // Attribute predictors: age, gender, mask and eye state.
    private AgePredictorProxy agePredictorProxy;
    private GenderPredictorProxy genderPredictorProxy;
    private MaskDetectorProxy maskDetectorProxy;
    private EyeStateDetectorProxy eyeStateDetectorProxy;

    // Load the SeetaFace native libraries once, at class initialisation,
    // i.e. before any instance (and therefore any proxy) is created.
    static {
        LoadNativeCore.LOAD_NATIVE(SeetaDevice.SEETA_DEVICE_AUTO);
    }
}
主类里的初始化方法:
/**
 * Creates every SeetaFace proxy from the model files configured in
 * {@code seetaModelConfiguration}; each model path is {@code basePath + fileName}.
 *
 * <p>Runs once after dependency injection. If a model fails to load, the error is
 * logged and startup continues; that proxy and every proxy created after it stay
 * {@code null}.
 */
@PostConstruct
private void inti() {
    String basePath = seetaModelConfiguration.basePath;
    try {
        // Face detector.
        faceDetectorProxy = new FaceDetectorProxy(
                seetaModelSetting(basePath, seetaModelConfiguration.faceDetectorModelFileName));
        // Landmark locator (5 points by default; can be switched to 68 points via configuration).
        faceLandmarkerProxy = new FaceLandmarkerProxy(
                seetaModelSetting(basePath, seetaModelConfiguration.faceLandmarkerModelFileName));
        // Face feature-vector extractor / comparator.
        faceRecognizerProxy = new FaceRecognizerProxy(
                seetaModelSetting(basePath, seetaModelConfiguration.faceRecognizerModelFileName));
        // Age predictor.
        agePredictorProxy = new AgePredictorProxy(
                seetaModelSetting(basePath, seetaModelConfiguration.agePredictorModelFileName));
        // Gender predictor.
        genderPredictorProxy = new GenderPredictorProxy(
                seetaModelSetting(basePath, seetaModelConfiguration.genderPredictorModelFileName));
        // Mask detector.
        maskDetectorProxy = new MaskDetectorProxy(
                seetaModelSetting(basePath, seetaModelConfiguration.maskDetectorModelFileName));
        // Eye-state detector.
        eyeStateDetectorProxy = new EyeStateDetectorProxy(
                seetaModelSetting(basePath, seetaModelConfiguration.eyeStateModelFileName));
    } catch (Exception e) {
        log.error("SeetaFace model initialisation failed, basePath [{}]", basePath, e);
    }
}

/**
 * Device index passed as the first argument of {@code SeetaModelSetting}.
 * In SeetaFace6 that field is the (GPU) device id, not a per-model identifier, so every
 * model uses device 0. The previous values 0..5 (5 used twice) would address
 * non-existent GPUs whenever SEETA_DEVICE_AUTO selects a GPU.
 */
private static final int SEETA_DEVICE_ID = 0;

/** Builds a single-model pool setting for {@code basePath + modelFileName} on the auto-selected device. */
private static SeetaConfSetting seetaModelSetting(String basePath, String modelFileName) {
    return new SeetaConfSetting(
            new SeetaModelSetting(SEETA_DEVICE_ID, new String[]{basePath + modelFileName},
                    SeetaDevice.SEETA_DEVICE_AUTO));
}
主类里的根据图片路径获取脸部特征向量方法:
/**
 * Extracts the face feature vector from the image at the given path.
 *
 * <p>Only the first detected face is used: the input is expected to contain a single
 * face (driver or front passenger); when several are detected, the first one wins.
 *
 * @param imagePath path of the image file
 * @return the face feature vector, or {@code null} if no face was detected or
 *         processing failed (failures are logged)
 */
public float[] getFaceFeaturesByPath(String imagePath) {
    try {
        SeetaImageData image = SeetafaceUtil.toSeetaImageData(imagePath);
        SeetaRect[] detects = faceDetectorProxy.detect(image);
        if (detects == null || detects.length == 0) {
            return null;
        }
        // Locate landmarks on the first face, then extract its feature vector.
        SeetaPointF[] pointFace = faceLandmarkerProxy.mark(image, detects[0]);
        return faceRecognizerProxy.extract(image, pointFace);
    } catch (Exception e) {
        log.error("getFaceFeaturesByPath [{}] failed", imagePath, e);
        return null;
    }
}
主类里的根据人像图片的路径获取其属性【年龄、性别、是否戴口罩、眼睛状态】方法:
/**
 * Predicts the attributes of the first face in the image at the given path:
 * age, gender, mask and eye state.
 *
 * <p>Map keys and value types:
 * {@code "age"} → {@code Integer}, {@code "gender"} → {@code GenderPredictor.GENDER},
 * {@code "mask"} → {@code Boolean}, {@code "eye"} → {@code String}
 * (the {@code Arrays.toString} form of the eye-state array).
 *
 * @param imagePath path of the image file
 * @return the attribute map; empty if no face was detected. If processing fails part
 *         way through, the failure is logged and the entries collected so far are
 *         returned.
 */
public Map<String, Object> getAttributeByPath(String imagePath) {
    long start = System.currentTimeMillis();
    // At most four entries; capacity 8 keeps them below the 0.75 load-factor resize threshold.
    Map<String, Object> attributeMap = new HashMap<>(8);
    try {
        SeetaImageData image = SeetafaceUtil.toSeetaImageData(imagePath);
        SeetaRect[] detects = faceDetectorProxy.detect(image);
        if (detects == null || detects.length == 0) {
            return attributeMap;
        }
        SeetaPointF[] pointFace = faceLandmarkerProxy.mark(image, detects[0]);
        int age = agePredictorProxy.predictAgeWithCrop(image, pointFace);
        attributeMap.put("age", age);
        GenderPredictor.GENDER gender = genderPredictorProxy.predictGenderWithCrop(image, pointFace).getGender();
        attributeMap.put("gender", gender);
        // The mask detector works on the face rectangle rather than on landmarks.
        boolean mask = maskDetectorProxy.detect(image, detects[0]).getMask();
        attributeMap.put("mask", mask);
        EyeStateDetector.EYE_STATE[] eyeStates = eyeStateDetectorProxy.detect(image, pointFace);
        attributeMap.put("eye", Arrays.toString(eyeStates));
        log.info("getAttributeByPath [{}] cost [{}]", imagePath, System.currentTimeMillis() - start);
    } catch (Exception e) {
        log.error("getAttributeByPath [{}] failed", imagePath, e);
    }
    return attributeMap;
}