1. Python部分:导出ONNX模型
首先,我们需要在Python中定义并导出一个已经训练好的验证码识别模型。以下是完整的Python代码:
import string
import torch
import torch.nn as nn
import torch.nn.functional as F
CHAR_SET = string.digits
# 优化后的模型设计
class CaptchaModel(nn.Module):
def __init__(self):
super(CaptchaModel, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
self.fc1 = nn.Linear(128 * 5 * 12, 256) # 调整为实际展平维度
self.fc2 = nn.Linear(256, 4 * len(CHAR_SET))
self.dropout = nn.Dropout(0.5)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = F.relu(F.max_pool2d(self.conv3(x), 2))
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x.view(-1, 4, len(CHAR_SET))
# 使用CUDA,如果可用的话
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# 假设你的模型已经训练好并保存在 'best_model.pth'
model = CaptchaModel().to(device)
model.load_state_dict(torch.load('best_model.pth'))
# 生成一个测试输入 (示例输入的形状应与模型输入形状一致)
dummy_input = torch.randn(1, 1, 40, 100).to(device)
# 导出模型为 ONNX 格式
torch.onnx.export(model, dummy_input, "captcha_model.onnx",
input_names=["input"], output_names=["output"],
dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}})
print("Model exported to captcha_model.onnx")
这段代码定义了一个验证码识别模型,并将其导出为ONNX格式,以便在Java中使用。
2. Java部分:调用ONNX模型进行验证码识别
接下来,我们使用Java调用导出的ONNX模型进行验证码识别。以下是完整的Java代码:
- 引用onnxruntime-1.19.0.jar
package com.tushuoit;
import ai.onnxruntime.*;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import java.util.List;
public class CaptchaInference {
private static final String CHAR_SET = "0123456789";
private static final int INPUT_WIDTH = 100;
private static final int INPUT_HEIGHT = 40;
private static final Random random = new Random();
public static void main(String[] args) throws Exception {
// 随机生成4个字符的验证码文本
String captchaText = generateRandomText(4);
System.out.println("Generated Captcha Text: " + captchaText);
// 生成包含文本的Bitmap (BufferedImage)
BufferedImage captchaImage = generateCaptcha(captchaText, 36, INPUT_WIDTH, INPUT_HEIGHT);
// 将Bitmap保存为文件(仅用于查看生成的图像,实际使用中可以省略)
ImageIO.write(captchaImage, "png", new File("generated_captcha.png"));
// 将图像转换为浮点数数组,并进行归一化处理
float[] inputData = imageToFloatArray(captchaImage);
// 创建ONNX Runtime环境
OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
// 加载ONNX模型
OrtSession session = env.createSession("captcha_model.onnx", opts);
// 创建输入张量
FloatBuffer inputBuffer = FloatBuffer.wrap(inputData);
OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputBuffer,
new long[] { 1, 1, INPUT_HEIGHT, INPUT_WIDTH });
// 进行推理
OrtSession.Result result = session.run(Collections.singletonMap("input", inputTensor));
// Extract output tensor and decode it
float[][][] outputData = (float[][][]) result.get(0).getValue();
List<String> decodedTexts = decodeOutput(outputData);
// Print the decoded captcha text
for (String text : decodedTexts) {
System.out.println("Predicted Captcha Text: " + text);
}
System.out.println("Inference completed.");
// 释放资源
session.close();
env.close();
}
// 随机生成指定长度的验证码文本
private static String generateRandomText(int length) {
StringBuilder text = new StringBuilder(length);
for (int i = 0; i < length; i++) {
text.append(CHAR_SET.charAt(random.nextInt(CHAR_SET.length())));
}
return text.toString();
}
// 生成包含文本的BufferedImage
private static BufferedImage generateCaptcha(String text, int fontSize, int width, int height) {
BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
Graphics2D g2d = image.createGraphics();
// 设置背景颜色为白色
g2d.setColor(Color.WHITE);
g2d.fillRect(0, 0, width, height);
// 设置字体和颜色
g2d.setFont(new Font("DroidSansMono", Font.PLAIN, fontSize));
g2d.setColor(Color.BLACK);
// 绘制文本
FontMetrics fm = g2d.getFontMetrics();
int x = 5; // 文字开始的X坐标
int y = fm.getAscent() + 5; // 文字开始的Y坐标
g2d.drawString(text, x, y);
g2d.dispose();
return image;
}
// 将BufferedImage转换为float数组,并进行归一化处理
private static float[] imageToFloatArray(BufferedImage image) {
int width = image.getWidth();
int height = image.getHeight();
float[] floatArray = new float[width * height];
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
int rgb = image.getRGB(x, y);
int gray = (rgb >> 16) & 0xFF; // 因为是灰度图,只需获取一个通道的值
floatArray[y * width + x] = (gray / 255.0f - 0.5f) * 2.0f; // 归一化到[-1, 1]
}
}
return floatArray;
}
private static List<String> decodeOutput(float[][][] outputData) {
List<String> decodedTexts = new ArrayList<>();
for (float[][] singleOutput : outputData) {
StringBuilder decodedText = new StringBuilder();
for (float[] charProbabilities : singleOutput) {
int maxIndex = getMaxIndex(charProbabilities);
decodedText.append(CHAR_SET.charAt(maxIndex));
}
decodedTexts.add(decodedText.toString());
}
return decodedTexts;
}
private static int getMaxIndex(float[] probabilities) {
int maxIndex = 0;
float maxProb = probabilities[0];
for (int i = 1; i < probabilities.length; i++) {
if (probabilities[i] > maxProb) {
maxProb = probabilities[i];
maxIndex = i;
}
}
return maxIndex;
}
}
这段Java代码首先生成一个随机的验证码图像,然后将其转换为模型输入格式,并通过ONNX Runtime调用导出的模型进行推理,最后解码模型的输出以获取识别的验证码文本。
总结
通过上述步骤,我们成功地在Python中导出了一个验证码识别模型,并在Java中调用该模型进行验证码识别。这种方法充分利用了Python在深度学习模型训练和导出方面的优势,以及Java在实际应用部署和性能方面的优势,实现了高效的验证码识别系统。