1.fastText
官网
- fastText是一个用于有效学习单词表示和句子分类的库
- fastText建立在现代Mac OS和Linux发行版上。因为它使用了c++ 11的特性,所以它需要一个具有良好的c++11支持的编译器
2.创建maven项目
maven配置:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.example</groupId>
<artifactId>fasttext-demo</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<maven.compiler.source>8</maven.compiler.source>
<maven.compiler.target>8</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<!-- ai.djl -->
<djl.version>0.27.0</djl.version>
</properties>
<dependencies>
<!--ai.djl -->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>basicdataset</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl.fasttext</groupId>
<artifactId>fasttext-engine</artifactId>
<version>${djl.version}</version>
</dependency>
<!-- SLF4J绑定到Log4j -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-reload4j</artifactId>
<version>1.7.36</version>
</dependency>
<!-- Log4j核心库 -->
<dependency>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
<version>1.2.17</version>
</dependency>
</dependencies>
<!-- 使用 aliyun 的 Maven 源,提升下载速度 -->
<repositories>
<repository>
<id>aliyunmaven</id>
<name>aliyun</name>
<url>https://maven.aliyun.com/repository/public</url>
</repository>
</repositories>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>8</source>
<target>8</target>
</configuration>
<version>3.8.1</version>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<version>3.3.0</version>
<configuration>
<descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef>
</descriptorRefs>
<!-- 可选配置,如设置JAR文件名、MANIFEST.MF信息等 -->
<!-- ... -->
</configuration>
<executions>
<execution>
<id>make-assembly</id>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
创建DemoDataset.java:
package cn.fasttext.demo;
import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.RawDataset;
import ai.djl.translate.TranslateException;
import ai.djl.util.Progress;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
/**
* 数据集配置
*
* @Author tanyong
*/
public class DemoDataset implements RawDataset<Path> {
private final String trainPath;
public DemoDataset(String localPath) {
this.trainPath = localPath;
}
@Override
public Path getData() throws IOException {
return new File(trainPath).toPath();
}
@Override
public Iterable<Batch> getData(NDManager ndManager) throws IOException, TranslateException {
return null;
}
@Override
public void prepare(Progress progress) throws IOException, TranslateException {
}
}
创建FastTextDemo.java用于模型学习和预测:
package cn.fasttext.demo;
import ai.djl.MalformedModelException;
import ai.djl.fasttext.FtModel;
import ai.djl.fasttext.FtTrainingConfig;
import ai.djl.fasttext.TrainFastText;
import ai.djl.fasttext.zoo.nlp.textclassification.FtTextClassification;
import ai.djl.modality.Classifications;
import ai.djl.training.TrainingResult;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
/**
* @Author tanyong
* @Version FastTextDemo v1.0.0 $$
*/
public class FastTextDemo {
private static final Logger logger = LoggerFactory.getLogger(FastTextDemo.class);
private final static String MODEL_PATH = "build/model_demo.bin";
public static boolean isWindows() {
String osName = System.getProperty("os.name");
return osName != null && osName.toLowerCase().startsWith("windows");
}
/**
* 训练
*
* @param params
* @throws IOException
*/
public void trainClassification(Map<String, String> params) throws IOException {
if (isWindows()) {
throw new RuntimeException("fastText is not supported on windows");
}
String trainPath = params.get("trainPath");
if (Objects.isNull(trainPath) || trainPath.length() == 0) {
throw new RuntimeException("trainPath 不能为空!");
}
// 模型保存路径
String outputDir = params.getOrDefault("outputDir", "build");
// 模型名称
String modelName = params.getOrDefault("modelName", "model_demo");
// 训练次数 [5 - 50]
String epoch = params.getOrDefault("epoch", "5");
// 学习率 default 0.1 [0.1 - 1.0]
String lr = params.getOrDefault("lr", "0.1");
// ngram [1 - 5] n-gram来指任何n个连续标记的连接
String ngram = params.getOrDefault("ngram", "1");
/**
* 损失函数
* NS (Negative Sampling): Negative Sampling是一种在大规模数据集上进行训练时常用的损失函数优化方法,特别是在词嵌入和文档分类任务中。在FastText中,它被用来减少计算softmax函数时的复杂度。具体来说,对于每个目标词(正样本),模型会随机抽取一小部分非目标词(负样本)。在每次迭代中,模型不仅预测目标词,还要区分这些负样本。这样做的好处是大大减少了需要计算概率的词汇数量,从而加快训练速度,同时保持了模型的学习效果。
* HS (Hierarchical Softmax):Hierarchical Softmax是一种替代常规softmax函数的方法,用于降低大规模词汇表上计算所有词的概率分布的复杂度。它通过构建霍夫曼树(Huffman Tree)将词汇表组织成层次结构,使得频繁出现的词在树中距离根节点较近,不常出现的词则位于较远的叶节点。在预测时,模型只需沿着树的路径计算一系列二分类问题,而非一次性计算整个词汇表的概率分布。这种方法同样可以有效提升训练效率,尤其是在词汇表非常大的场景下
* SOFTMAX:Softmax函数是最常见的多类别分类任务损失函数之一。在FastText中,当未使用NS或HS优化时,模型直接使用softmax函数计算每个词在词汇表中的概率分布。给定词向量表示,softmax函数会将其映射到一个概率分布,使得所有类别(词)的概率和为1。模型的目标是最大化目标词的概率,同时最小化非目标词的概率。虽然softmax函数提供了完整的概率分布,但它在大规模词汇表上的计算成本较高,因此对于非常大的数据集,通常会使用NS或HS代替。
* OVA (One-Versus-All):One-Versus-All,也称为One-vs-the-Rest或多类单标签分类,是一种将多类分类问题转化为多个二分类问题的技术。在FastText中,OVA通常用于多标签分类任务,即一个样本可能属于多个类别。对于每个类别,模型会训练一个独立的二分类器,该分类器的任务是区分当前类别与其他所有类别。在预测时,模型对每个类别应用其对应的二分类器,得到该类别是否属于样本的预测结果。OVA方法在FastText中不太常见,因为它通常与多标签分类相关,而FastText更常用于词嵌入和文档分类任务。
*/
String loss = params.getOrDefault("loss", "hs");
// 学习数据集路径
DemoDataset demoDataset = new DemoDataset(trainPath);
FtTrainingConfig config =
FtTrainingConfig.builder()
.setOutputDir(Paths.get(outputDir))
.setModelName(modelName)
.optEpoch(Integer.parseInt(epoch))
.optLearningRate(Float.parseFloat(lr))
.optMaxNGramLength(Integer.parseInt(ngram))
.optLoss(FtTrainingConfig.FtLoss.valueOf(loss.toUpperCase()))
.build();
FtTextClassification block = TrainFastText.textClassification(config, demoDataset);
TrainingResult result = block.getTrainingResult();
assert result.getEpoch() != Integer.parseInt(epoch) : "Epoch Error";
String modelPath = outputDir + File.separator + modelName + ".bin";
assert Files.exists(Paths.get(modelPath)) : "bin not found";
}
/**
* load model
*
* @return
* @throws MalformedModelException
* @throws IOException
*/
public FtModel loadModel() throws MalformedModelException, IOException {
Path path = Paths.get(MODEL_PATH);
assert Files.exists(path) : "bin not found";
FtModel model = new FtModel("model_demo");
model.load(path);
return model;
}
public static void main(String[] args) throws IOException {
Map<String, String> params = new HashMap<>(16);
for (String arg : args) {
String[] param = arg.split("=");
if (Objects.nonNull(param) && param.length == 2) {
params.put(param[0], param[1]);
}
}
String type = params.get("type");
if (Objects.isNull(type) || type.length() == 0) {
throw new RuntimeException("type not null");
}
FastTextDemo fastTextDemo = null;
switch (type) {
case "train":
fastTextDemo = new FastTextDemo();
fastTextDemo.trainClassification(params);
break;
case "classification":
fastTextDemo = new FastTextDemo();
try {
FtModel model = fastTextDemo.loadModel();
Scanner scanner = new Scanner(System.in);
while (true) {
logger.info("\n请输入一段文字(输入 'quit' 退出程序):");
String inputString = scanner.nextLine();
if ("quit".equalsIgnoreCase(inputString)) {
break; // 用户输入 'quit' 时,跳出循环,结束程序
}
if (Objects.isNull(inputString) || inputString.length() == 0) {
continue;
}
try {
Classifications result = ((FtTextClassification) model.getBlock()).classify(inputString, 5);
logger.info(result.getAsString());
} catch (Exception e) {
e.printStackTrace();
}
}
scanner.close();
} catch (Exception e) {
throw new RuntimeException(e);
}
break;
default:
break;
}
}
}
3.打包部署到服务器
打包:
mvn clean package
执行命令成功生成fasttext-demo-1.0-SNAPSHOT-jar-with-dependencies.jar
训练:
下载数据集:
wget https://dl.fbaipublicfiles.com/fasttext/data/cooking.stackexchange.tar.gz && tar xvzf cooking.stackexchange.tar.gz
head cooking.stackexchange.txt
文本文件的每一行都包含一个标签列表,后面跟着相应的文档。所有标签都以__label__前缀开头。文本内容需要分词,用空格隔开.
执行训练命令:
java -cp fasttext-demo-1.0-SNAPSHOT-jar-with-dependencies.jar cn.fasttext.demo.FastTextDemo trainPath=/home/hadoop/app/fasttext-demo/train.txt type=train epoch=25 lr=0.5 ngram=1 loss=hs
出现异常:
Exception in thread "main" java.lang.UnsatisfiedLinkError: /root/.djl.ai/fasttext/0.9.2-0.27.0/libjni_fasttext.so: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.20' not found (required by /root/.djl.ai/fasttext/0.9.2-0.27.0/libjni_fasttext.so)
at java.lang.ClassLoader$NativeLibrary.load(Native Method)
at java.lang.ClassLoader.loadLibrary0(ClassLoader.java:1934)
at java.lang.ClassLoader.loadLibrary(ClassLoader.java:1817)
at java.lang.Runtime.load0(Runtime.java:782)
at java.lang.System.load(System.java:1100)
at ai.djl.fasttext.jni.LibUtils.loadLibrary(LibUtils.java:57)
at ai.djl.fasttext.jni.FtWrapper.<clinit>(FtWrapper.java:29)
at ai.djl.fasttext.zoo.nlp.textclassification.FtTextClassification.fit(FtTextClassification.java:66)
at ai.djl.fasttext.TrainFastText.textClassification(TrainFastText.java:37)
at cn.fasttext.demo.FastTextDemo.trainClassification(FastTextDemo.java:77)
at cn.fasttext.demo.FastTextDemo.main(FastTextDemo.java:101)
c++的版本太低需要升级版本
查询GLIBCXX版本,fasttext要求GLIBCXX_3.4.20,查询出来只有GLIBCXX_3.4.19
strings /usr/lib64/libstdc++.so.6 |grep GLIBCXX*
下载 高版本:
wget http://www.vuln.cn/wp-content/uploads/2019/08/libstdc.so_.6.0.26.zip
unzip libstdc.so_.6.0.26.zip
#复制到/usr/lib64
cp libstdc++.so.6.0.26 /usr/lib64
查看软连接
ls -l | grep libstdc++
重新建立高版本的软连接:
ln -sf /usr/lib64/libstdc++.so.6.0.26 /usr/lib64/libstdc++.so.6
重新查看软连接:
ls -l | grep libstdc++
通过命令查询:
strings /usr/lib64/libstdc++.so.6 |grep GLIBCXX*
问题解决
重新执行训练命令输出结果:
build文件夹下能查看到模型文件,model_demo.bin model_demo.vec
预测:
执行命令
java -cp fasttext-demo-1.0-SNAPSHOT-jar-with-dependencies.jar cn.fasttext.demo.FastTextDemo trainPath=/home/hadoop/app/fasttext-demo/train.txt type=classification
运行结果: