由于工作需要,需要对embedding模型进行微调,我调用了几种方案,都比较繁琐。先记录一个相对简单的方案。以下内容并不一定正确,请刷到的大佬给予指正,不胜感激!!!
一.对BGE模型,如bge-large-zh 、bge-large-en
二.对sentensce embedding bert model ,如多语言模型 distiluse-base-multilingual-cased-v1
三.对于sentence embedding bert model 使用 towhee 进行微调,以下主要对这个方案进行阐述:
做微调之前需要准备微调样本数据,准备方式,我目前思考跟第二种方案是一样的。我偷懒,先验证代码可以跑通,所以用了example dataset
1. git clone 代码.
参考链接:
sentence-embedding/sbert - sbert - Towhee
git clone https://towhee.io/sentence-embedding/sbert.git
2. 配置python环境
conda create -n sentence-embedding-3.9 python=3.9 -y
#进入代码根目录
pip install -r requirement.txt
pip install towhee
3.运行微调代码(其实就是继续训练原有模型)
修改微调核心代码如下:
if __name__ == '__main__':
PROJ_DIR = '/data2/04_embedding/finetune/sentence-embedding/'
sys.path.append(os.path.join(PROJ_DIR, 'sbert'))
from sentence_transformers import util
# op = STransformers(model_name='nli-distilroberta-base-v2')
op = STransformers(model_name='distiluse-base-multilingual-cased-v1')
# Check if dataset exsist. If not, download and extract it
sts_dataset_path = 'datasets/stsbenchmark.tsv.gz'
if not os.path.exists(sts_dataset_path):
util.http_get('https://sbert.net/datasets/stsbenchmark.tsv.gz', sts_dataset_path)
training_config = {
'sts_dataset_path': sts_dataset_path,
'train_batch_size': 16,
'num_epochs': 4,
'model_save_path': './output'
}
op.train(training_config)
python s_bert.py
发生下载数据集错误
urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='sbert.net', port=443): Max retries exceeded with url: /datasets/stsbenchmark.tsv.gz (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f6bf717ad00>: Failed to establish a new connection: [Errno 101] Network is unreachable'))
手动下载,并放到datasets目录下。
发生相对包引用错误
You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ()`.
warnings.warn(
Traceback (most recent call last):
File "/data2/04_embedding/finetune/sentence-embedding/sbert/s_bert.py", line 281, in <module>
op.train(training_config)
File "/data2/04_embedding/finetune/sentence-embedding/sbert/s_bert.py", line 260, in train
from .train_sts_task import train_sts
ImportError: attempted relative import with no known parent package
参考我的另一个记录:ImportError: attempted relative import with no known parent package-CSDN博客
微调结果:
剩下的就是测试模型了(待续)。。。