随着chat-gpt等机器人对话框架的流行,让一个名为gradio的框架也火热起来,这个框架可以开启一个http服务,并且带输入输出界面,可以让对话类的人工智能项目快速运行。
gradio号称可以快速部署ai可视化项目。
下面通过两个示例来感受一下,首先我们需要安装gradio库。
pip install gradio
接着编写如下的代码,用户输入一个字符串xxx,提交之后,输出一个hello,xxx 。
import gradio as gr
def hello(name):
return "hello," + name + "!"
def launch():
demo = gr.Interface(fn=hello, inputs='text', outputs='text')
demo.launch()
if __name__ == '__main__':
launch()
运行这段代码,可以开启7860端口监听http服务, 浏览器访问http://localhost:7860,可以打开如下界面:
再编写一个示例,是关于图像识别的,代码如下:
import torch
from PIL import Image
from torchvision import transforms
import gradio as gr
import json
with open('imagenet-simple-labels.json', 'r') as load_f:
labels = json.load(load_f)
model = torch.hub.load("pytorch/vision:v0.6.0", "resnet18", pretrained=True).eval()
def predict(inp):
inp = Image.fromarray(inp.astype("uint8"), "RGB")
inp = transforms.ToTensor()(inp).unsqueeze(0)
with torch.no_grad():
prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
return {labels[i]: float(prediction[i]) for i in range(1000)}
inputs = gr.Image()
outputs = gr.Label(num_top_classes=3)
demo = gr.Interface(fn=predict, inputs=inputs, outputs=outputs)
if __name__ == '__main__':
demo.launch()
运行代码,会下载pytorch/vision:v0.6.0版本,并下载一个resnet18的模型文件:resnet18-f37072fd.pth到用户目录下的.cache\torch\hub\checkpoints\目录下。
运行打印信息如下:
我们打开浏览器http://localhost:7860,在界面上选择我们事先准备好的豹子和狗的图片:
这里识别了豹子,显示cheetah。
换一只狗的再试一下:
识别结果为一只拉布拉多。
代码中设置了三个最可能的结果,outputs = gr.Label(num_top_classes=3),所以这里会列出最有可能的三种情况。
以上代码运行的时候报了警告:
D:\Program Files\Python\Python310\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
warnings.warn(
D:\Program Files\Python\Python310\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
warnings.warn(msg)
意思是 torch.hub.load加载模型的时候,pretrained参数过时了,可以使用weights=ResNet18_Weights.DEFAULT替代。
修改代码之后,就不报警告了。如下所示:
官网的例子,文中有个文件来自https://git.io/JJkYN,现在已经无法下载了,但是它可以直接在github找到:https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json
这里就是提前下载,然后通过json读取,内容是1000个目标标签。