根据《attention is all you need》论文而形成的transformers框架在chat-gpt应用中大放异彩,目前transformers框架已经成了炙手可热的框架。它不仅在nlp方面很作用很大,根据官网的介绍,它还可以做很多事情,比如图片分类,目标检测。
下面结合官网示例,给出两个简单的示例,一个是文本处理,另一个是目标检测。
transformers框架提供了pipeline的方式,可以快速运用一个模型到输入对象上。官方的原话是:
To immediately use a model on a given input (text, image, audio, ...), we provide the
pipeline
API
在进行示例之前,我们需要安装transformers框架,本机安装的是transformers=4.26.1
pip install transformers==4.26.1
第一个文本处理的例子,利用transformers快速区分积极和消极的文本内容。如下所示,我们输入一段文字,transformer会给出判断:
from transformers import pipeline
classifier = pipeline('sentiment-analysis')
res = classifier('we are happy to indroduce pipeline to the transformers repository.')
print(res)
运行这段代码,可以得到如下结果:
运行这个示例,会先去下载一些模型文件,这里默认会下载distilbert-base-uncased-finetuned-sst-2-english模型文件,并存放到本机用户目录下的.cache\hugginface\hub\models--distilbert-base-uncased-finetuned-sst-2-english目录下,如下所示:
换一句话试一下:i am sorry to hear that you are sick.
这次识别文本语义为negative,也就是消极或者反面的,符合预期。
第二个示例是目标检测的,先准备一张图片,然后利用目标检测模型来识别。
import requests
from PIL import Image
from transformers import pipeline
# Download an image with cute cats
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png"
image_data = requests.get(url, stream=True).raw
image = Image.open(image_data)
# Allocate a pipeline for object detection
object_detector = pipeline('object-detection')
res = object_detector(image)
print(res)
这个图片如下所示:
打印结果如下:
[{'score': 0.9982201457023621,
'label': 'remote',
'box': {'xmin': 40, 'ymin': 70, 'xmax': 175, 'ymax': 117}},
{'score': 0.9960021376609802,
'label': 'remote',
'box': {'xmin': 333, 'ymin': 72, 'xmax': 368, 'ymax': 187}},
{'score': 0.9954745173454285,
'label': 'couch',
'box': {'xmin': 0, 'ymin': 1, 'xmax': 639, 'ymax': 473}},
{'score': 0.9988006353378296,
'label': 'cat',
'box': {'xmin': 13, 'ymin': 52, 'xmax': 314, 'ymax': 470}},
{'score': 0.9986783862113953,
'label': 'cat',
'box': {'xmin': 345, 'ymin': 23, 'xmax': 640, 'ymax': 368}}]
分别识别了图中的沙发、遥控器、猫咪。就像下面这张图一样:
和第一个示例类似,这个示例运行的时候,会下载facebook/detr-resnet-50模型文件。并存放在用户目录下的.cache\huggingface\hub\models--facebook--detr-resnet-50目录下。还会下载一个名为resnet50_a1_0-14fe96d1.pth的文件放到.cache\torch\hub\checkpoints目录下。