Spring Boot集成DeepLearning4j实现图片数字识别

news2024/11/17 7:21:51

1.什么是DeepLearning4j?

DeepLearning4J(DL4J)是一套基于Java语言的神经网络工具包,可以构建、定型和部署神经网络。DL4J与Hadoop和Spark集成,支持分布式CPU和GPU,为商业环境(而非研究工具目的)所设计。Skymind是DL4J的商业支持机构。 Deeplearning4j拥有先进的技术,以即插即用为目标,通过更多预设的使用,避免多余的配置,让非企业也能够进行快速的原型制作。DL4J同时可以规模化定制。DL4J遵循Apache 2.0许可协议,一切以其为基础的衍生作品均属于衍生作品的作

Deeplearning4j的功能

Deeplearning4j包括了分布式、多线程的深度学习框架,以及普通的单线程深度学习框架。定型过程以集群进行,也就是说,Deeplearning4j可以快速处理大量数据。神经网络可通过[迭代化简]平行定型,与 Java、 Scala 和 Clojure 均兼容。Deeplearning4j在开放堆栈中作为模块组件的功能,使之成为首个为微服务架构打造的深度学习框架。

​​​​​​​

Deeplearning4j的组件

深度神经网络能够实现前所未有的准确度。对神经网络的简介请参见概览页。简而言之,Deeplearning4j能够让你从各类浅层网络(其中每一层在英文中被称为layer)出发,设计深层神经网络。这一灵活性使用户可以根据所需,在分布式、生产级、能够在分布式CPU或GPU的基础上与Spark和Hadoop协同工作的框架内,整合受限玻尔兹曼机、其他自动编码器、卷积网络或递归网络。 此处为我们已经建立的各个库及其在系统整体中的所处位置:  

dl4j-ecosystem-cn-small

DeepLearning4J用于设计神经网络:

  • Deeplearning4j(简称DL4J)是为Java和Scala编写的首个商业级开源分布式深度学习
  • DL4J与Hadoop和Spark集成,为商业环境(而非研究工具目的)所设计。
  • 支持GPU和CPU
  • 受到 Cloudera, Hortonwork, NVIDIA, Intel, IBM 等认证,可以在Spark, Flink, Hadoop 上运行
  • 支持并行迭代算法架构
  • DeepLearning4J的JavaDoc可在此处获取
  • DeepLearning4J示例的Github代码库请见此处。相关示例的简介汇总请见此处。
  • 开源工具 ASF 2.0许可证:github.com/deeplearning4j/deeplearning4j

2.训练模型

训练和测试数据集下载

https://raw.githubusercontent.com/zq2599/blog_download_files/master/files/mnist_png.tar.gz

MNIST简介

  • MNIST是经典的计算机视觉数据集,来源是National Institute of Standards and Technology (NIST,美国国家标准与技术研究所),包含各种手写数字图片,其中训练集60,000张,测试集 10,000张,
  • MNIST来源于250 个不同人的手写,其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员.,测试集(test set) 也是同样比例的手写数字数据
  • MNIST官网:http://yann.lecun.com/exdb/mnist/

数据集简介

从MNIST官网下载的原始数据并非图片文件,需要按官方给出的格式说明做解析处理才能转为一张张图片,这些事情显然不是本篇的主题,因此咱们可以直接使用DL4J为我们准备好的数据集(下载地址稍后给出),该数据集中是一张张独立的图片,这些图片所在目录的名字就是该图片具体的数字

模型训练

LeNet-5简介

1351564-20180827204056354-1429986291

LeNet-5 结构:
  • 输入层

图片大小为 32×32×1,其中 1 表示为黑白图像,只有一个 channel。

  • 卷积层

filter 大小 5×5,filter 深度(个数)为 6,padding 为 0, 卷积步长 s=1=1,输出矩阵大小为 28×28×6,其中 6 表示 filter 的个数。

  • 池化层

average pooling,filter 大小 2×2(即 f=2=2),步长 s=2=2,no padding,输出矩阵大小为 14×14×6。

  • 卷积层

filter 大小 5×5,filter 个数为 16,padding 为 0, 卷积步长 s=1=1,输出矩阵大小为 10×10×16,其中 16 表示 filter 的个数。

  • 池化层

average pooling,filter 大小 2×2(即 f=2=2),步长 s=2=2,no padding,输出矩阵大小为 5×5×16。注意,在该层结束,需要将 5×5×16 的矩阵flatten 成一个 400 维的向量。

  • 全连接层(Fully Connected layer,FC)

neuron 数量为 120。

  • 全连接层(Fully Connected layer,FC)

neuron 数量为 84。

  • 全连接层,输出层

现在版本的 LeNet-5 输出层一般会采用 softmax 激活函数,在 LeNet-5 提出的论文中使用的激活函数不是 softmax,但其现在不常用。该层神经元数量为 10,代表 0~9 十个数字类别。(图 1 其实少画了一个表示全连接层的方框,而直接用 ^y^ 表示输出层。)  

/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package com.et.dl4j.model;

import lombok.extern.slf4j.Slf4j;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;

import java.io.File;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

/**
 * Implementation of LeNet-5 for handwritten digits image classification on MNIST dataset (99% accuracy)
 * <a href="http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf">[LeCun et al., 1998. Gradient based learning applied to document recognition]</a>
 * Some minor changes are made to the architecture like using ReLU and identity activation instead of
 * sigmoid/tanh, max pooling instead of avg pooling and softmax output layer.
 * <p>
 * This example will download 15 Mb of data on the first run.
 *
 * @author hanlon
 * @author agibsonccc
 * @author fvaleri
 * @author dariuszzbyrad
 */
@Slf4j
public class LeNetMNISTReLu {
   //dataset github:https://raw.githubusercontent.com/zq2599/blog_download_files/master/files/mnist_png.tar.gz
    // 存放文件的地址,请酌情修改
//    private static final String BASE_PATH = System.getProperty("java.io.tmpdir") + "/mnist";
    private static final String BASE_PATH = "/Users/liuhaihua/Downloads";

    public static void main(String[] args) throws Exception {
        // 图片像素高
        int height = 28;
        // 图片像素宽
        int width = 28;

        // 因为是黑白图像,所以颜色通道只有一个
        int channels = 1;

        // 分类结果,0-9,共十种数字
        int outputNum = 10;

        // 批大小
        int batchSize = 54;

        // 循环次数
        int nEpochs = 1;

        // 初始化伪随机数的种子
        int seed = 1234;

        // 随机数工具
        Random randNumGen = new Random(seed);

        log.info("检查数据集文件夹是否存在:{}", BASE_PATH + "/mnist_png");

        if (!new File(BASE_PATH + "/mnist_png").exists()) {
            log.info("数据集文件不存在,请下载压缩包并解压到:{}", BASE_PATH);
            return;
        }

        // 标签生成器,将指定文件的父目录作为标签
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        // 归一化配置(像素值从0-255变为0-1)
        DataNormalization imageScaler = new ImagePreProcessingScaler();

        // 不论训练集还是测试集,初始化操作都是相同套路:
        // 1. 读取图片,数据格式为NCHW
        // 2. 根据批大小创建的迭代器
        // 3. 将归一化器作为预处理器

        log.info("训练集的矢量化操作...");
        // 初始化训练集
        File trainData = new File(BASE_PATH + "/mnist_png/training");
        FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
        ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);
        trainRR.initialize(trainSplit);
        DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum);
        // 拟合数据(实现类中实际上什么也没做)
        imageScaler.fit(trainIter);
        trainIter.setPreProcessor(imageScaler);

        log.info("测试集的矢量化操作...");
        // 初始化测试集,与前面的训练集操作类似
        File testData = new File(BASE_PATH + "/mnist_png/testing");
        FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
        ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker);
        testRR.initialize(testSplit);
        DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);
        testIter.setPreProcessor(imageScaler); // same normalization for better results

        log.info("配置神经网络");

        // 在训练中,将学习率配置为随着迭代阶梯性下降
        Map<Integer, Double> learningRateSchedule = new HashMap<>();
        learningRateSchedule.put(0, 0.06);
        learningRateSchedule.put(200, 0.05);
        learningRateSchedule.put(600, 0.028);
        learningRateSchedule.put(800, 0.0060);
        learningRateSchedule.put(1000, 0.001);

        // 超参数
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            // L2正则化系数
            .l2(0.0005)
            // 梯度下降的学习率设置
            .updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, learningRateSchedule)))
            // 权重初始化
            .weightInit(WeightInit.XAVIER)
            // 准备分层
            .list()
            // 卷积层
            .layer(new ConvolutionLayer.Builder(5, 5)
                .nIn(channels)
                .stride(1, 1)
                .nOut(20)
                .activation(Activation.IDENTITY)
                .build())
            // 下采样,即池化
            .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
            // 卷积层
            .layer(new ConvolutionLayer.Builder(5, 5)
                .stride(1, 1) // nIn need not specified in later layers
                .nOut(50)
                .activation(Activation.IDENTITY)
                .build())
            // 下采样,即池化
            .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
            // 稠密层,即全连接
            .layer(new DenseLayer.Builder().activation(Activation.RELU)
                .nOut(500)
                .build())
            // 输出
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(outputNum)
                .activation(Activation.SOFTMAX)
                .build())
            .setInputType(InputType.convolutionalFlat(height, width, channels)) // InputType.convolutional for normal image
            .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // 每十个迭代打印一次损失函数值
        net.setListeners(new ScoreIterationListener(10));

        log.info("神经网络共[{}]个参数", net.numParams());

        long startTime = System.currentTimeMillis();
        // 循环操作
        for (int i = 0; i < nEpochs; i++) {
            log.info("第[{}]个循环", i);
            net.fit(trainIter);
            Evaluation eval = net.evaluate(testIter);
            log.info(eval.stats());
            trainIter.reset();
            testIter.reset();
        }
        log.info("完成训练和测试,耗时[{}]毫秒", System.currentTimeMillis()-startTime);

        // 保存模型
        File ministModelPath = new File(BASE_PATH + "/minist-model.zip");
        ModelSerializer.writeModel(net, ministModelPath, true);
        log.info("最新的MINIST模型保存在[{}]", ministModelPath.getPath());
    }
}

输出模型文件和得分结果

dl2

3.编写模型预测接口

pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <artifactId>springboot-demo</artifactId>
        <groupId>com.et</groupId>
        <version>1.0-SNAPSHOT</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>

    <artifactId>Deeplearning4j</artifactId>

    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <dl4j-master.version>1.0.0-beta7</dl4j-master.version>
        <nd4j.backend>nd4j-native</nd4j.backend>
    </properties>
    <dependencies>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-autoconfigure</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.18.20</version>
        </dependency>

        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>

        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>

        <!--用于本地GPU-->
        <!--        <dependency>-->
        <!--            <groupId>org.deeplearning4j</groupId>-->
        <!--            <artifactId>deeplearning4j-cuda-9.2</artifactId>-->
        <!--            <version>${dl4j-master.version}</version>-->
        <!--        </dependency>-->

        <!--        <dependency>-->
        <!--            <groupId>org.nd4j</groupId>-->
        <!--            <artifactId>nd4j-cuda-9.2-platform</artifactId>-->
        <!--            <version>${dl4j-master.version}</version>-->
        <!--        </dependency>-->


    </dependencies>
</project>

cotroller

package com.et.dl4j.controller;

import com.et.dl4j.service.PredictService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;

import java.util.HashMap;
import java.util.Map;

@RestController
public class HelloWorldController {
    @RequestMapping("/hello")
    public Map<String, Object> showHelloWorld(){
        Map<String, Object> map = new HashMap<>();
        map.put("msg", "HelloWorld");
        return map;
    }
   @Autowired
    PredictService predictService;


   @PostMapping("/predict-with-black-background")
   public int predictWithBlackBackground(@RequestParam("file") MultipartFile file) throws Exception {
      // 训练模型的时候,用的数字是白字黑底,
      // 因此如果上传白字黑底的图片,可以直接拿去识别,而无需反色处理
      return predictService.predict(file, false);
   }

   @PostMapping("/predict-with-white-background")
   public int predictWithWhiteBackground(@RequestParam("file") MultipartFile file) throws Exception {
      // 训练模型的时候,用的数字是白字黑底,
      // 因此如果上传黑字白底的图片,就需要做反色处理,
      // 反色之后就是白字黑底了,可以拿去识别
      return predictService.predict(file, true);
   }
}

service

package com.et.dl4j.service;

import org.springframework.web.multipart.MultipartFile;


public interface PredictService {

    /**
     * 取得上传的图片,做转换后识别成数字
     * @param file 上传的文件
     * @param isNeedRevert 是否要做反色处理
     * @return
     */
    int predict(MultipartFile file, boolean isNeedRevert) throws Exception ;
}
package com.et.dl4j.service.impl;
import com.et.dl4j.service.PredictService;
import com.et.dl4j.util.ImageFileUtil;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

import javax.annotation.PostConstruct;
import java.io.File;


@Service
@Slf4j
public class PredictServiceImpl implements PredictService {

    /**
     * -1表示识别失败
     */
    private static final int RLT_INVALID = -1;

    /**
     * 模型文件的位置
     */
    @Value("${predict.modelpath}")
    private String modelPath;

    /**
     * 处理图片文件的目录
     */
    @Value("${predict.imagefilepath}")
    private String imageFilePath;

    /**
     * 神经网络
     */
    private MultiLayerNetwork net;

    /**
     * bean实例化成功就加载模型
     */
    @PostConstruct
    private void loadModel() {
        log.info("load model from [{}]", modelPath);

        // 加载模型
        try {
            net = ModelSerializer.restoreMultiLayerNetwork(new File(modelPath));
            log.info("module summary\n{}", net.summary());
        } catch (Exception exception) {
            log.error("loadModel error", exception);
        }
    }

    @Override
    public int predict(MultipartFile file, boolean isNeedRevert) throws Exception {
        log.info("start predict, file [{}], isNeedRevert [{}]", file.getOriginalFilename(), isNeedRevert);

        // 先存文件
        String rawFileName = ImageFileUtil.save(imageFilePath, file);

        if (null==rawFileName) {
            return RLT_INVALID;
        }

        // 反色处理后的文件名
        String revertFileName = null;

        // 调整大小后的文件名
        String resizeFileName;

        // 是否需要反色处理
        if (isNeedRevert) {
            // 把原始文件做反色处理,返回结果是反色处理后的新文件
            revertFileName = ImageFileUtil.colorRevert(imageFilePath, rawFileName);

            // 把反色处理后调整为28*28大小的文件
            resizeFileName = ImageFileUtil.resize(imageFilePath, revertFileName);
        } else {
            // 直接把原始文件调整为28*28大小的文件
            resizeFileName = ImageFileUtil.resize(imageFilePath, rawFileName);
        }

        // 现在已经得到了结果反色和调整大小处理过后的文件,
        // 那么原始文件和反色处理过的文件就可以删除了
        ImageFileUtil.clear(imageFilePath, rawFileName, revertFileName);

        // 取出该黑白图片的特征
        INDArray features = ImageFileUtil.getGrayImageFeatures(imageFilePath, resizeFileName);
        
        // 将特征传给模型去识别
        return net.predict(features)[0];
    }
}

application.properties

# 上传文件总的最大值
spring.servlet.multipart.max-request-size=1024MB

# 单个文件的最大值
spring.servlet.multipart.max-file-size=10MB

# 处理图片文件的目录
predict.imagefilepath=/Users/liuhaihua/Downloads/images/

# 模型所在位置
predict.modelpath=/Users/liuhaihua/Downloads/minist-model.zip

工具类

package com.et.dl4j.util;

import lombok.extern.slf4j.Slf4j;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.springframework.web.multipart.MultipartFile;

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.UUID;


@Slf4j
public class ImageFileUtil {

    /**
     * 调整后的文件宽度
     */
    public static final int RESIZE_WIDTH = 28;

    /**
     * 调整后的文件高度
     */
    public static final int RESIZE_HEIGHT = 28;

    /**
     * 将上传的文件存在服务器上
     * @param base 要处理的文件所在的目录
     * @param file 要处理的文件
     * @return
     */
    public static String save(String base, MultipartFile file) {

        // 检查是否为空
        if (file.isEmpty()) {
            log.error("invalid file");
            return null;
        }

        // 文件名来自原始文件
        String fileName = file.getOriginalFilename();

        // 要保存的位置
        File dest = new File(base + fileName);

        // 开始保存
        try {
            file.transferTo(dest);
        } catch (IOException e) {
            log.error("upload fail", e);
            return null;
        }

        return fileName;
    }

    /**
     * 将图片转为28*28像素
     * @param base     处理文件的目录
     * @param fileName 待调整的文件名
     * @return
     */
    public static String resize(String base, String fileName) {

        // 新文件名是原文件名在加个随机数后缀,而且扩展名固定为png
        String resizeFileName = fileName.substring(0, fileName.lastIndexOf(".")) + "-" + UUID.randomUUID() + ".png";

        log.info("start resize, from [{}] to [{}]", fileName, resizeFileName);

        try {
            // 读原始文件
            BufferedImage bufferedImage = ImageIO.read(new File(base + fileName));

            // 缩放后的实例
            Image image = bufferedImage.getScaledInstance(RESIZE_WIDTH, RESIZE_HEIGHT, Image.SCALE_SMOOTH);

            BufferedImage resizeBufferedImage = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);
            Graphics graphics = resizeBufferedImage.getGraphics();

            // 绘图
            graphics.drawImage(image, 0, 0, null);
            graphics.dispose();

            // 转换后的图片写文件
            ImageIO.write(resizeBufferedImage, "png", new File(base + resizeFileName));

        } catch (Exception exception) {
            log.info("resize error from [{}] to [{}], {}", fileName, resizeFileName, exception);
            resizeFileName = null;
        }

        log.info("finish resize, from [{}] to [{}]", fileName, resizeFileName);

        return resizeFileName;
    }

    /**
     * 将RGB转为int数字
     * @param alpha
     * @param red
     * @param green
     * @param blue
     * @return
     */
    private static int colorToRGB(int alpha, int red, int green, int blue) {
        int pixel = 0;

        pixel += alpha;
        pixel = pixel << 8;

        pixel += red;
        pixel = pixel << 8;

        pixel += green;
        pixel = pixel << 8;

        pixel += blue;

        return pixel;
    }

    /**
     * 反色处理
     * @param base 处理文件的目录
     * @param src 用于处理的源文件
     * @return 反色处理后的新文件
     * @throws IOException
     */
    public static String colorRevert(String base, String src) throws IOException {
        int color, r, g, b, pixel;

        // 读原始文件
        BufferedImage srcImage = ImageIO.read(new File(base + src));

        // 修改后的文件
        BufferedImage destImage = new BufferedImage(srcImage.getWidth(), srcImage.getHeight(), srcImage.getType());

        for (int i=0; i<srcImage.getWidth(); i++) {

            for (int j=0; j<srcImage.getHeight(); j++) {
                color = srcImage.getRGB(i, j);
                r = (color >> 16) & 0xff;
                g = (color >> 8) & 0xff;
                b = color & 0xff;
                pixel = colorToRGB(255, 0xff - r, 0xff - g, 0xff - b);
                destImage.setRGB(i, j, pixel);
            }
        }

        // 反射文件的名字
        String revertFileName =  src.substring(0, src.lastIndexOf(".")) + "-revert.png";

        // 转换后的图片写文件
        ImageIO.write(destImage, "png", new File(base + revertFileName));

        return revertFileName;
    }

    /**
     * 取黑白图片的特征
     * @param base
     * @param fileName
     * @return
     * @throws Exception
     */
    public static INDArray getGrayImageFeatures(String base, String fileName) throws Exception {
        log.info("start getImageFeatures [{}]", base + fileName);

        // 和训练模型时一样的设置
        ImageRecordReader imageRecordReader = new ImageRecordReader(RESIZE_HEIGHT, RESIZE_WIDTH, 1);

        FileSplit fileSplit = new FileSplit(new File(base + fileName),
                NativeImageLoader.ALLOWED_FORMATS);

        imageRecordReader.initialize(fileSplit);

        DataSetIterator dataSetIterator = new RecordReaderDataSetIterator(imageRecordReader, 1);
        dataSetIterator.setPreProcessor(new ImagePreProcessingScaler(0, 1));

        // 取特征
        return dataSetIterator.next().getFeatures();
    }

    /**
     * 批量清理文件
     * @param base      处理文件的目录
     * @param fileNames 待清理文件集合
     */
    public static void clear(String base, String...fileNames) {
        for (String fileName : fileNames) {

            if (null==fileName) {
                continue;
            }

            File file = new File(base + fileName);

            if (file.exists()) {
                file.delete();
            }
        }
    }


}

DemoApplication.java

package com.et.dl4j;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
public class DemoApplication {

   public static void main(String[] args) {
      SpringApplication.run(DemoApplication.class, args);
   }
}

以上只是一些关键代码,所有代码请参见下面代码仓库

代码仓库

  • https://github.com/Harries/springboot-demo

4.测试

启动Spring Boot应用,上传图片测试

  • 如果用户输入的是黑底白字的图片,只需要将上述流程中的反色处理去掉即可
  • 为白底黑字图片提供专用接口predict-with-white-background
  • 为黑底白字图片提供专用接口predict-with-black-background

dl3

5.引用

  • 关于我们 - Deeplearning4j: Open-source, Distributed Deep Learning for the JVM
  • DL4J实战之三:经典卷积实例(LeNet-5)_multilayerconfiguration 参数-CSDN博客
  • Spring Boot集成DeepLearning4j实现图片数字识别 | Harries Blog™

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1884624.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《Windows API每日一练》7.4 状态报告上使用计时器

这一节我们使用计时器&#xff0c;每隔一秒获取当前鼠标坐标位置的像素值&#xff0c;并显示在窗口&#xff0c;这就相当于是一个简单的取色器了。 本节必须掌握的知识点&#xff1a; 第47练&#xff1a;取色器 7.4.1 第47练&#xff1a;取色器 /*----------------------------…

商家团购app微信小程序模板

手机微信商家团购小程序页面&#xff0c;商家订餐外卖小程序前端模板下载。包含&#xff1a;团购主页、购物车订餐页面、我的订单、个人主页等。 商家团购app微信小程序模板

昇思25天学习打卡营第13天|ResNet50图像分类

1. 学习内容复盘 图像分类是最基础的计算机视觉应用&#xff0c;属于有监督学习类别&#xff0c;如给定一张图像(猫、狗、飞机、汽车等等)&#xff0c;判断图像所属的类别。本章将介绍使用ResNet50网络对CIFAR-10数据集进行分类。 ResNet网络介绍 ResNet50网络是2015年由微软…

使用Git从Github上克隆仓库,修改并提交修改

前言 本次任务主要是进行github提交修改的操作练习实践&#xff0c;本文章是对实践过程以及遇到的问题进行的一个记录。 在此之前&#xff0c;我已经简单使用过github&#xff0c;Git之前已经下好了&#xff0c;所以就省略一些步骤。 步骤记录 注册github账号&#xff0c;gi…

PS系统教程31

调色之色阶 调色与通道最基本的关系通道是记录颜色最基本的信息有些图片可以用通道去改变颜色信息的说明这些图像是比较高级的PS是一款图像合成软件&#xff0c;在合成过程中需要处理大量素材&#xff0c;比如要用这些素材进行抠背景&#xff0c;就要用到图层蒙版以及Alpha通道…

Go语言--格式化输出输入、类型转换

格式说明 %T操作变量所属类型自动匹配格式的不一定很正确&#xff0c;尤其是字符类型&#xff0c;本应该是整型&#xff0c;实际上他会输出数字 输入 阻塞等待用户的输入 fmt.Scanf("%d", &a)fmt.Scan(&b)不需要写格式&#xff0c;自动匹配 类型转换 类…

深入学习 Kafka(1)- 核心组件

组件概述 1. Producer&#xff08;消息生产者&#xff09; 核心作用&#xff1a;生成数据源&#xff0c;将消息发送至指定Topic。关键特性&#xff1a;支持批量发送、分区策略选择&#xff0c;以及可配置的重试逻辑&#xff0c;提高了数据传输效率和可靠性。 2. Topic&#x…

iptable精讲

SNAT策略 SNAT策略的典型应用环境 局域网主机共享单个公网IP地址接入Internet SNAT策略的原理 源地址转换&#xff0c;Source Network Address Translantion 修改数据包的源地址 部署SNAT策略 1.准备二台最小化虚拟机修改主机名 主机名&#xff1a;gw 主机名&#xff1…

SpringBoot 项目整合 MyBatisPlus 框架,附带测试示例

文章目录 一、创建 SpringBoot 项目二、添加 MyBatisPlus 依赖三、项目结构和数据库表结构四、项目代码1、application.yml2、TestController3、TbUser4、TbUserMapper5、TestServiceImpl6、TestService7、TestApplication8、TbUserMapper.xml9、MyBatisPlusTest 五、浏览器测试…

云服务出现故障这样处理

无法连接云服务器 服务器远程无法连接时&#xff0c;可通过7ECloud控制台进行连接。 常见故障现象 1、ping不通 2、ping丢包 3、部分端口telnet不通 4、全部端口telnet不通 5、广告、弹窗植入 6、域名无法访问IP访问正常 常见故障原因 1、云服务器过期、关机或者EIP被…

【技巧】ArcgisPro 字段计算器内置函数方法的调用

在arcgisPro中&#xff0c;内置了常用的几种函数方法&#xff0c;如顺序编号&#xff0c;重分类等&#xff1b;调用方法如下&#xff1a;

YOLOv3分析

参考链接&#xff1a;霹雳吧啦b站 主要参考了b站霹雳吧啦的视频《深度学习目标检测篇》。 目录 前言YOLOv3网络结构 YOLOv3 SPP 前言 YOLOv3的精度虽然已经过时&#xff0c;但思想仍旧值得学习&#xff0c;本帖记录所需所想的一些内容。 YOLOv3 网络结构 一共53层&#xff0…

软件著作权申请:保障开发者权益,促进软件创新

一、引言 在数字化时代&#xff0c;软件作为信息技术的核心&#xff0c;已成为推动社会进步和经济发展的重要力量。然而&#xff0c;随着软件产业的蓬勃发展&#xff0c;软件侵权和抄袭现象也日益严重。为了保护软件开发者的合法权益&#xff0c;促进软件产业的健康发展&#…

一维信号短时傅里叶变换域邻域降噪方法(MATLAB)

噪声在人类日常生活中无处不在,其会降低语音信号的质量和可懂度。在低信噪比的恶劣环境中,这种负面影响愈发严重。为了解决这个问题,众多研究人员在过去的几十年里提出了许多降噪算法。 根据原理的不同,降噪算法可大致分为五类:谱减法、最优滤波法、基于统计模型的方法、子空间…

【Java EE】Spring IOCDI

Spring IOC & DI 文章目录 Spring IOC & DI一、Spring是什么&#xff1f;二、IOC(控制反转)2.1 通俗理解2.2 造汽车的例子理解IOC2.3 IOC详解1. 获取Bean2. 方法注解——Bean1. 应用场景&#xff1a;2. 应用方法&#xff1a;3. 注意要点&#xff1a; 特别注意: 四、DI4…

基于springboot+vue+uniapp的超市售货管理平台

开发语言&#xff1a;Java框架&#xff1a;springbootuniappJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#…

工具网源码|计算某一天是当年的第几周,第几天,黄历工具

源码简介&#xff1a; 工具网源码|计算某一天是当年的第几周&#xff0c;第几天&#xff0c;黄历工具&#xff0c;无文章系统&#xff0c;全站同过算法实现&#xff0c;并在前台展示&#xff0c;页面多&#xff0c;很适合seo。 演示站&#xff1a; https://dijizhou.wengu8.com…

从项目中学习Bus-Off的快慢恢复

0 前言 说到Bus-Off&#xff0c;大家应该都不陌生&#xff0c;使用VH6501干扰仪进行测试的文章在网上数不胜数&#xff0c;但是一般大家都是教怎么去干扰&#xff0c;但是说如何去看快慢恢复以及对快慢恢复做出解释比较少&#xff0c;因此本文以实践的视角来讲解Bus-Off的快慢恢…

期末复习---程序填空

注意&#xff1a; 1.数组后移 *p *(p-1) //把前一个数赋值到后一个数的位置上来覆盖后一个数 2.指针找最大字符 max *p while( *p){ if( max< *p) { max*p; qp;/ 用新的指针指向这个已经找到的最大位置&#xff1b;!!!!!!!!! } p; //因为开始没有next &#xff…

网络爬虫(二) 哔哩哔哩热榜高频词按照图片形状排列

我们有时候需要爬取结果生成为自定义的词云图 生成自定义的词云图通常需要以下步骤&#xff1a; 1. 爬取数据&#xff1a;使用爬虫工具或库&#xff0c;如requests、BeautifulSoup等&#xff0c;可以爬取网页、论坛、社交媒体等平台上的文本数据。 2. 数据预处理&#xff1a…