java使用fasttext实现文本分类

news2024/10/6 20:32:16

1.fastText

官网

  • fastText是一个用于有效学习单词表示和句子分类的库
  • fastText建立在现代Mac OS和Linux发行版上。因为它使用了c++ 11的特性,所以它需要一个具有良好的c++11支持的编译器

2.创建maven项目

maven配置:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>fasttext-demo</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <!-- ai.djl -->
        <djl.version>0.27.0</djl.version>
    </properties>

    <dependencies>
        <!--ai.djl -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.fasttext</groupId>
            <artifactId>fasttext-engine</artifactId>
            <version>${djl.version}</version>
        </dependency>

        <!-- SLF4J绑定到Log4j -->
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-reload4j</artifactId>
            <version>1.7.36</version>
        </dependency>

        <!-- Log4j核心库 -->
        <dependency>
            <groupId>log4j</groupId>
            <artifactId>log4j</artifactId>
            <version>1.2.17</version>
        </dependency>

    </dependencies>

    <!-- 使用 aliyun 的 Maven 源,提升下载速度 -->
    <repositories>
        <repository>
            <id>aliyunmaven</id>
            <name>aliyun</name>
            <url>https://maven.aliyun.com/repository/public</url>
        </repository>
    </repositories>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>8</source>
                    <target>8</target>
                </configuration>
                <version>3.8.1</version>
            </plugin>

            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-assembly-plugin</artifactId>
                <version>3.3.0</version>
                <configuration>
                    <descriptorRefs>
                        <descriptorRef>jar-with-dependencies</descriptorRef>
                    </descriptorRefs>
                    <!-- 可选配置,如设置JAR文件名、MANIFEST.MF信息等 -->
                    <!-- ... -->
                </configuration>
                <executions>
                    <execution>
                        <id>make-assembly</id>
                        <phase>package</phase>
                        <goals>
                            <goal>single</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>

</project>

创建DemoDataset.java:

package cn.fasttext.demo;

import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.RawDataset;
import ai.djl.translate.TranslateException;
import ai.djl.util.Progress;

import java.io.File;
import java.io.IOException;
import java.nio.file.Path;

/**
 * 数据集配置
 *
 * @Author tanyong
 */
public class DemoDataset implements RawDataset<Path> {

    private final String trainPath;

    public DemoDataset(String localPath) {
        this.trainPath = localPath;
    }

    @Override
    public Path getData() throws IOException {
        return new File(trainPath).toPath();
    }

    @Override
    public Iterable<Batch> getData(NDManager ndManager) throws IOException, TranslateException {
        return null;
    }

    @Override
    public void prepare(Progress progress) throws IOException, TranslateException {

    }
}

创建FastTextDemo.java用于模型学习和预测:

package cn.fasttext.demo;

import ai.djl.MalformedModelException;
import ai.djl.fasttext.FtModel;
import ai.djl.fasttext.FtTrainingConfig;
import ai.djl.fasttext.TrainFastText;
import ai.djl.fasttext.zoo.nlp.textclassification.FtTextClassification;
import ai.djl.modality.Classifications;
import ai.djl.training.TrainingResult;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;

/**
 * @Author tanyong
 * @Version FastTextDemo v1.0.0 $$
 */
public class FastTextDemo {

    private static final Logger logger = LoggerFactory.getLogger(FastTextDemo.class);


    private final static String MODEL_PATH = "build/model_demo.bin";

    public static boolean isWindows() {
        String osName = System.getProperty("os.name");
        return osName != null && osName.toLowerCase().startsWith("windows");
    }


    /**
     * 训练
     *
     * @param params
     * @throws IOException
     */
    public void trainClassification(Map<String, String> params) throws IOException {
        if (isWindows()) {
            throw new RuntimeException("fastText is not supported on windows");
        }

        String trainPath = params.get("trainPath");

        if (Objects.isNull(trainPath) || trainPath.length() == 0) {
            throw new RuntimeException("trainPath 不能为空!");
        }

        // 模型保存路径
        String outputDir = params.getOrDefault("outputDir", "build");
        // 模型名称
        String modelName = params.getOrDefault("modelName", "model_demo");
        // 训练次数 [5 - 50]
        String epoch = params.getOrDefault("epoch", "5");
        // 学习率 default 0.1 [0.1 - 1.0]
        String lr = params.getOrDefault("lr", "0.1");
        // ngram [1 - 5] n-gram来指任何n个连续标记的连接
        String ngram = params.getOrDefault("ngram", "1");
        /**
         *  损失函数
         * NS (Negative Sampling): Negative Sampling是一种在大规模数据集上进行训练时常用的损失函数优化方法,特别是在词嵌入和文档分类任务中。在FastText中,它被用来减少计算softmax函数时的复杂度。具体来说,对于每个目标词(正样本),模型会随机抽取一小部分非目标词(负样本)。在每次迭代中,模型不仅预测目标词,还要区分这些负样本。这样做的好处是大大减少了需要计算概率的词汇数量,从而加快训练速度,同时保持了模型的学习效果。
         * HS (Hierarchical Softmax):Hierarchical Softmax是一种替代常规softmax函数的方法,用于降低大规模词汇表上计算所有词的概率分布的复杂度。它通过构建霍夫曼树(Huffman Tree)将词汇表组织成层次结构,使得频繁出现的词在树中距离根节点较近,不常出现的词则位于较远的叶节点。在预测时,模型只需沿着树的路径计算一系列二分类问题,而非一次性计算整个词汇表的概率分布。这种方法同样可以有效提升训练效率,尤其是在词汇表非常大的场景下
         * SOFTMAX:Softmax函数是最常见的多类别分类任务损失函数之一。在FastText中,当未使用NS或HS优化时,模型直接使用softmax函数计算每个词在词汇表中的概率分布。给定词向量表示,softmax函数会将其映射到一个概率分布,使得所有类别(词)的概率和为1。模型的目标是最大化目标词的概率,同时最小化非目标词的概率。虽然softmax函数提供了完整的概率分布,但它在大规模词汇表上的计算成本较高,因此对于非常大的数据集,通常会使用NS或HS代替。
         * OVA (One-Versus-All):One-Versus-All,也称为One-vs-the-Rest或多类单标签分类,是一种将多类分类问题转化为多个二分类问题的技术。在FastText中,OVA通常用于多标签分类任务,即一个样本可能属于多个类别。对于每个类别,模型会训练一个独立的二分类器,该分类器的任务是区分当前类别与其他所有类别。在预测时,模型对每个类别应用其对应的二分类器,得到该类别是否属于样本的预测结果。OVA方法在FastText中不太常见,因为它通常与多标签分类相关,而FastText更常用于词嵌入和文档分类任务。
         */
        String loss = params.getOrDefault("loss", "hs");


        // 学习数据集路径
        DemoDataset demoDataset = new DemoDataset(trainPath);

        FtTrainingConfig config =
                FtTrainingConfig.builder()
                        .setOutputDir(Paths.get(outputDir))
                        .setModelName(modelName)
                        .optEpoch(Integer.parseInt(epoch))
                        .optLearningRate(Float.parseFloat(lr))
                        .optMaxNGramLength(Integer.parseInt(ngram))
                        .optLoss(FtTrainingConfig.FtLoss.valueOf(loss.toUpperCase()))
                        .build();

        FtTextClassification block = TrainFastText.textClassification(config, demoDataset);
        TrainingResult result = block.getTrainingResult();
        assert result.getEpoch() != Integer.parseInt(epoch) : "Epoch Error";
        String modelPath = outputDir + File.separator + modelName + ".bin";
        assert Files.exists(Paths.get(modelPath)) : "bin not found";

    }

    /**
     * load model
     *
     * @return
     * @throws MalformedModelException
     * @throws IOException
     */
    public FtModel loadModel() throws MalformedModelException, IOException {
        Path path = Paths.get(MODEL_PATH);
        assert Files.exists(path) : "bin not found";
        FtModel model = new FtModel("model_demo");
        model.load(path);
        return model;
    }


    public static void main(String[] args) throws IOException {
        Map<String, String> params = new HashMap<>(16);
        for (String arg : args) {
            String[] param = arg.split("=");
            if (Objects.nonNull(param) && param.length == 2) {
                params.put(param[0], param[1]);
            }
        }
        String type = params.get("type");
        if (Objects.isNull(type) || type.length() == 0) {
            throw new RuntimeException("type not null");
        }
        FastTextDemo fastTextDemo = null;
        switch (type) {
            case "train":
                fastTextDemo = new FastTextDemo();
                fastTextDemo.trainClassification(params);
                break;
            case "classification":
                fastTextDemo = new FastTextDemo();
                try {
                    FtModel model = fastTextDemo.loadModel();
                    Scanner scanner = new Scanner(System.in);
                    while (true) {
                        logger.info("\n请输入一段文字(输入 'quit' 退出程序):");
                        String inputString = scanner.nextLine();

                        if ("quit".equalsIgnoreCase(inputString)) {
                            break; // 用户输入 'quit' 时,跳出循环,结束程序
                        }
                        if (Objects.isNull(inputString) || inputString.length() == 0) {
                            continue;
                        }
                        try {
                            Classifications result = ((FtTextClassification) model.getBlock()).classify(inputString, 5);
                            logger.info(result.getAsString());
                        } catch (Exception e) {
                            e.printStackTrace();
                        }
                    }
                    scanner.close();

                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
                break;
            default:
                break;
        }

    }
}

3.打包部署到服务器

打包:

mvn clean package

执行命令成功生成fasttext-demo-1.0-SNAPSHOT-jar-with-dependencies.jar
在这里插入图片描述
训练:
下载数据集:

wget https://dl.fbaipublicfiles.com/fasttext/data/cooking.stackexchange.tar.gz && tar xvzf cooking.stackexchange.tar.gz
head cooking.stackexchange.txt

文本文件的每一行都包含一个标签列表,后面跟着相应的文档。所有标签都以__label__前缀开头。文本内容需要分词,用空格隔开.
执行训练命令:

java -cp fasttext-demo-1.0-SNAPSHOT-jar-with-dependencies.jar cn.fasttext.demo.FastTextDemo trainPath=/home/hadoop/app/fasttext-demo/train.txt type=train epoch=25 lr=0.5 ngram=1 loss=hs

出现异常:

Exception in thread "main" java.lang.UnsatisfiedLinkError: /root/.djl.ai/fasttext/0.9.2-0.27.0/libjni_fasttext.so: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.20' not found (required by /root/.djl.ai/fasttext/0.9.2-0.27.0/libjni_fasttext.so)
	at java.lang.ClassLoader$NativeLibrary.load(Native Method)
	at java.lang.ClassLoader.loadLibrary0(ClassLoader.java:1934)
	at java.lang.ClassLoader.loadLibrary(ClassLoader.java:1817)
	at java.lang.Runtime.load0(Runtime.java:782)
	at java.lang.System.load(System.java:1100)
	at ai.djl.fasttext.jni.LibUtils.loadLibrary(LibUtils.java:57)
	at ai.djl.fasttext.jni.FtWrapper.<clinit>(FtWrapper.java:29)
	at ai.djl.fasttext.zoo.nlp.textclassification.FtTextClassification.fit(FtTextClassification.java:66)
	at ai.djl.fasttext.TrainFastText.textClassification(TrainFastText.java:37)
	at cn.fasttext.demo.FastTextDemo.trainClassification(FastTextDemo.java:77)
	at cn.fasttext.demo.FastTextDemo.main(FastTextDemo.java:101)

c++的版本太低需要升级版本
查询GLIBCXX版本,fasttext要求GLIBCXX_3.4.20,查询出来只有GLIBCXX_3.4.19

strings /usr/lib64/libstdc++.so.6 |grep GLIBCXX*

在这里插入图片描述
下载 高版本:

wget http://www.vuln.cn/wp-content/uploads/2019/08/libstdc.so_.6.0.26.zip
unzip libstdc.so_.6.0.26.zip
#复制到/usr/lib64
cp libstdc++.so.6.0.26 /usr/lib64

查看软连接

ls -l | grep libstdc++

在这里插入图片描述
重新建立高版本的软连接:

ln -sf /usr/lib64/libstdc++.so.6.0.26 /usr/lib64/libstdc++.so.6

重新查看软连接:

ls -l | grep libstdc++

在这里插入图片描述
通过命令查询:

strings /usr/lib64/libstdc++.so.6 |grep GLIBCXX*

在这里插入图片描述
问题解决

重新执行训练命令输出结果:
在这里插入图片描述
build文件夹下能查看到模型文件,model_demo.bin model_demo.vec

预测:
执行命令

java -cp fasttext-demo-1.0-SNAPSHOT-jar-with-dependencies.jar cn.fasttext.demo.FastTextDemo trainPath=/home/hadoop/app/fasttext-demo/train.txt type=classification

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1607546.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

14 Php学习:表单

表单 PHP 表单是用于收集用户输入的工具&#xff0c;通常用于网站开发。PHP 可以与 HTML 表单一起使用&#xff0c;用于处理用户提交的数据。通过 PHP 表单&#xff0c;您可以创建各种类型的表单&#xff0c;包括文本输入框、复选框、下拉菜单等&#xff0c;以便用户可以填写和…

AIGC算法1:Layer normalization

1. Layer Normalization μ E ( X ) ← 1 H ∑ i 1 n x i σ ← Var ⁡ ( x ) 1 H ∑ i 1 H ( x i − μ ) 2 ϵ y x − E ( x ) Var ⁡ ( X ) ϵ ⋅ γ β \begin{gathered}\muE(X) \leftarrow \frac{1}{H} \sum_{i1}^n x_i \\ \sigma \leftarrow \operatorname{Var}(…

Java | Leetcode Java题解之第35题搜索插入位置

题目&#xff1a; 题解&#xff1a; class Solution {public int searchInsert(int[] nums, int target) {int n nums.length;int left 0, right n - 1, ans n;while (left < right) {int mid ((right - left) >> 1) left;if (target < nums[mid]) {ans mi…

数字化转型对企业产生的影响

一、引言 在信息化、网络化的时代背景下&#xff0c;数字化转型已成为企业发展的必由之路。随着云计算、大数据、人工智能等技术的快速发展&#xff0c;数字化转型不仅改变了企业的运营方式&#xff0c;更深刻影响着企业的核心竞争力。本文将探讨数字化转型对企业产生的影响&a…

3D开发工具HOOPS助力CAM软件优化制造流程

在现代制造业中&#xff0c;计算机辅助制造&#xff08;CAM&#xff09;软件的发展已成为提高生产效率和产品质量的关键。为了满足不断增长的需求和日益复杂的制造流程&#xff0c;CAM软件需要具备高效的CAD数据导入、云端协作、移动应用支持以及丰富的文档生成能力。 Tech So…

羊大师分析,4月的羊奶好喝吗?

羊大师分析&#xff0c;4月的羊奶好喝吗&#xff1f; 4月的羊奶同样好喝。羊奶的口感和品质并不完全取决于月份&#xff0c;而更多地与奶源的品质、生产工艺以及保存方式等因素有关。羊大师作为知名品牌&#xff0c;一直以来都注重提供高品质的羊奶产品。 在4月这个春季时节&a…

redis写入和查询

import redis #redis的表名 redis_biao "Ruijieac_sta" #redis连接信息 redis_obj redis.StrictRedis(hostIP地址, port6379, db1, password密码) # keyytressdfg # value22 ##写入 # redis_obj.hset(redis_biao, key, value) #查询 req_redisredis_obj.hget(red…

【SGDR】《SGDR:Stochastic Gradient Descent with Warm Restarts》

arXiv-2016 code: https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py 文章目录 1 Background and Motivation2 Related Work3 Advantages / Contributions4 Method5 Experiments5.1 Datasets and Metric5.2 Single-Model Results5.3 Ensemble Results5.4 Experiment…

Modality-Aware Contrastive Instance Learning with Self-Distillation ... 论文阅读

Modality-Aware Contrastive Instance Learning with Self-Distillation for Weakly-Supervised Audio-Visual Violence Detection 论文阅读 ABSTRACT1 INTRODUCTION2 RELATEDWORKS2.1 Weakly-Supervised Violence Detection2.2 Contrastive Learning2.3 Cross-Modality Knowle…

基于Java SpringBoot+Vue的校园周边美食探索及分享平台的研究与实现,附源码

博主介绍&#xff1a;✌IT徐师兄、7年大厂程序员经历。全网粉丝15W、csdn博客专家、掘金/华为云//InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;&#x1f3…

vue+node使用RSA非对称加密,实现登录接口加密密码

背景 登录接口&#xff0c;密码这种重要信息不可以用明文传输&#xff0c;必须加密处理。 这里就可以使用RSA非对称加密&#xff0c;后端生成公钥和私钥。 公钥&#xff1a;给前端&#xff0c;公钥可以暴露出来&#xff0c;没有影响&#xff0c;因为公钥加密的数据只有私钥才…

类和对象(中)(构造函数、析构函数和拷贝构造函数)

1.类的六个默认成员函数 任何类在什么都不写时&#xff0c;编译器会自动生成以下6个默认成员函数。 //空类 class Date{}; 默认成员函数&#xff1a;用户没有显示实现&#xff0c;编译器会自动生成的成员函数称为默认成员函数 2.构造函数 构造函数 是一个 特殊的成员函数&a…

网络分析工具

为了实现业务目标&#xff0c;每天都要在网络上执行大量操作&#xff0c;网络管理员很难了解网络中实际发生的情况、谁消耗的带宽最多&#xff0c;并分析是否正在发生任何可能导致带宽拥塞的活动。对于大型企业和分布式网络来说&#xff0c;这些挑战是多方面的&#xff0c;为了…

[Leetcode]用栈实现队列

用栈实现队列&#xff1a; 请你仅使用两个栈实现先入先出队列。队列应当支持一般队列支持的所有操作&#xff08;push、pop、peek、empty&#xff09;&#xff1a; 实现 MyQueue 类&#xff1a; void push(int x) 将元素 x 推到队列的末尾int pop() 从队列的开头移除并返回元…

【智能算法】鸡群优化算法(CSO)原理及实现

目录 1.背景2.算法原理2.1算法思想2.2算法过程 3.结果展示4.参考文献 1.背景 2014年&#xff0c;X Meng等人受到鸡群社会行为启发&#xff0c;提出了鸡群优化算法&#xff08;Chicken Swarm Optimization, CSO&#xff09;。 2.算法原理 2.1算法思想 CSO算法的思想是基于对…

(六)PostgreSQL的组织结构(3)-默认角色和schema

PostgreSQL的组织结构(3)-默认角色和schema 基础信息 OS版本&#xff1a;Red Hat Enterprise Linux Server release 7.9 (Maipo) DB版本&#xff1a;16.2 pg软件目录&#xff1a;/home/pg16/soft pg数据目录&#xff1a;/home/pg16/data 端口&#xff1a;57771 默认角色 Post…

软考135-上午题-【软件工程】-软件配置管理

备注&#xff1a; 该部分考题内容在教材中找不到。直接背题目 一、配置数据库 配置数据库可以分为以下三类&#xff1a; (1) 开发库 专供开发人员使用&#xff0c;其中的信息可能做频繁修改&#xff0c;对其控制相当宽松 (2) 受控库 在生存期某一阶段工作结束时发布的阶段产…

手机拍摄视频怎么做二维码?现场录制视频一键生成二维码

随着手机摄像头的像素不断提升&#xff0c;现在经常会通过手机的拍摄视频&#xff0c;然后发送给其他人查看。当我们想要将一个视频分享给多人去查看时&#xff0c;如果一个个去发送会比较的浪费时间&#xff0c;而且对方还需要下载接受视频后才可以查看&#xff0c;时间成本高…

简化PLC图纸绘制流程:利用SOLIDWORKS Electrical提升效率与准确性

效率一向是工程师比较注重的问题&#xff0c;为了提高工作效率&#xff0c;工程师绞尽脑汁。而在SOLIDWORKS Electrical绘制plc原理图时能有效提高PLC图纸的出图效率&#xff0c;并且可以减少数据误差。 在SOLIDWORKS Electrical绘制PLC图纸时&#xff0c;可以先创建PLC输入/输…

域名被污染了只能换域名吗?

域名污染是指域名的解析结果受到恶意干扰或篡改&#xff0c;使得用户在访问相关网站时出现异常。很多域名遭遇过污染的情况&#xff0c;但是并不知道是域名污染&#xff0c;具体来说&#xff0c;域名污染可能表现为以下情况&#xff1a;用户无法通过输入正确的域名访问到目标网…