以下是一个使用 NIM 平台的生成式 AI 模型构建的简单 demo。这个 demo 实现了文生图功能:先解析用户的输入,判断其是否有画图需求,再决定是否调用文生图模型。这里使用 Python 和 FastAPI 框架来搭建一个简单的 web 应用。
项目结构
work/
│
├── imgs/
├── chat.py
└── chat.html
安装依赖
pip install fastapi uvicorn openai requests
创建聊天应用
from openai import OpenAI
from fastapi import FastAPI, Query
import uvicorn
from fastapi.middleware.cors import CORSMiddleware
import requests, base64
from datetime import datetime
import json
# Endpoint used for image generation requests.
invoke_url = "https://ai.api.nvidia.com/v1/genai/nvidia/consistory"

# HTTP headers sent with every image request; replace the placeholder after
# "Bearer " with a real API key.
headers = dict(
    Authorization="Bearer 你的apikey",
    Accept="application/json",
)
def draw_image(subject_prompt, subject_tokens, style_prompt, scene_prompt1, *, timeout=120):
    """Request an image from the generation endpoint and save it under ``imgs/``.

    Args:
        subject_prompt: Description of the image subject.
        subject_tokens: Subject keywords; forwarded unchanged in the payload.
        style_prompt: Description of the image style.
        scene_prompt1: Scene description; also sent as ``scene_prompt2``.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The URL string built around the saved file's relative path.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
        requests.Timeout: If the endpoint does not answer within ``timeout``.
    """
    import os  # local import: the file's import block does not provide os

    print("subject_prompt=====", subject_prompt)
    print("subject_tokens=====", subject_tokens)
    print("style_prompt=====", style_prompt)
    print("scene_prompt1=====", scene_prompt1)

    payload = {
        "mode": 'init',
        "subject_prompt": subject_prompt,
        "subject_tokens": subject_tokens,
        "subject_seed": 43,
        "style_prompt": style_prompt,
        "scene_prompt1": scene_prompt1,
        "scene_prompt2": scene_prompt1,
        "negative_prompt": "",
        "cfg_scale": 5,
        "same_initial_noise": False
    }

    # Without a timeout a stalled connection would hang the request forever.
    response = requests.post(invoke_url, headers=headers, json=payload, timeout=timeout)
    response.raise_for_status()
    data = response.json()
    img_bytes = base64.b64decode(data['artifacts'][0]["base64"])

    # str(datetime.now()) contains spaces and colons, which are invalid in
    # Windows file names and would need escaping in the returned URL, so a
    # compact timestamp is used instead.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    # The output directory may not exist yet on a fresh checkout.
    os.makedirs("imgs", exist_ok=True)
    fileName = f"imgs/{stamp}.jpg"
    with open(fileName, "wb") as f:
        f.write(img_bytes)

    # Return the image link, i.e. the value to use as an <img> src.
    return "https://abc/files/myWorkSapce/wordk1/" + fileName + "?_xsrf=2%7Cc6e76894%7Cfa2f2f15a513717ec5fe62cb04591a57%7C1728806376"
# Chat-completions client; replace the placeholder key with a real API key.
client = OpenAI(
    api_key="你的apikey",
    base_url="https://integrate.api.nvidia.com/v1",
)
# System prompt for the drawing-intent check. It asks the model to answer with
# a JSON object holding subject_prompt, subject_tokens, style_prompt and
# scene_prompt1 when the user wants a picture. The text is sent to the model
# verbatim, so any edit to it changes the model's behavior.
img_template = """分析用户是否有关于画图的意向,如果有画图意向,则根据用户的描述按照下面的 json 格式进行输出,输出内容一定要满足json格式,输出内容不要有其它冗余:
'{
"subject_prompt": 图片的主题(使用英文输出,一定要描述清楚,不要省略,比如“猫”),
"subject_tokens": 图像的主题描述词汇,输出为数组格式[词汇, 词汇],使用英文输出,
"style_prompt": 图像的风格,使用英文输出,
"scene_prompt1": 场景描述,使用英文输出
}'
"""
app = FastAPI()
# Configure the CORS middleware.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allow every origin
    # Credentials stay disabled: the chat page sends no cookies, and
    # credentialed requests must not be combined with wildcard origins.
    allow_credentials=False,
    allow_methods=["*"],  # allow every HTTP method
    allow_headers=["*"],  # allow every request header
)
def need_draw(json_str):
    """Return True only when *json_str* parses as a JSON object.

    Args:
        json_str: Text to inspect; may be ``None``.

    Returns:
        ``True`` if the text decodes to a ``dict``; ``False`` for invalid
        JSON, for non-string input, and for valid JSON that is not an
        object (numbers, strings, arrays, ``true``/``false``/``null``).
    """
    try:
        parsed = json.loads(json_str)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers None or other non-text input.
        return False
    # Callers read fields with .get(), which only a JSON object supports.
    return isinstance(parsed, dict)
@app.get("/get_answer")
async def get_answer(question: str = Query(..., description="输入的问题")):
    """Answer a question with either a generated image link or a text reply.

    The question is first sent to the model with ``img_template`` as the system
    prompt. If the reply is a JSON object, its fields are passed to
    ``draw_image``; otherwise the question is asked again as a normal chat.

    Args:
        question: The user's question, taken from the query string.

    Returns:
        ``{"answer": <image URL>, "isImage": True}`` for a drawing request,
        otherwise ``{"answer": <text>, "isImage": False}``.
    """
    # NOTE(review): the model and image calls below are synchronous, so they
    # appear to block the event loop while running inside this async handler.
    imageCompletion = client.chat.completions.create(
        model="meta/llama3-70b-instruct",
        messages=[{"role": "system", "content": img_template}, {"role": "user", "content": question}],
        temperature=0.5,
        top_p=1,
        max_tokens=1024,
        stream=False
    )
    print("imageCompletion=====", imageCompletion)

    intent_reply = imageCompletion.choices[0].message.content
    # The reply may be empty (None); only text can be checked for JSON.
    if isinstance(intent_reply, str) and need_draw(intent_reply):
        data = json.loads(intent_reply)
        # Valid JSON is not necessarily an object: a bare number or string
        # has no .get(), so fall through to the normal chat in that case.
        if isinstance(data, dict):
            image_url = draw_image(
                data.get("subject_prompt"),
                data.get("subject_tokens"),
                data.get("style_prompt"),
                data.get("scene_prompt1"),
            )
            return {"answer": image_url, "isImage": True}

    completion = client.chat.completions.create(
        model="meta/llama3-70b-instruct",
        messages=[{"role": "system", "content": "You are a helpful assistant, Communicate using Chinese."}, {"role": "user", "content": question}],
        temperature=0.5,
        top_p=1,
        max_tokens=1024,
        stream=True
    )
    # Collect the streamed pieces, skipping chunks with no choices or no text.
    pieces = [
        chunk.choices[0].delta.content
        for chunk in completion
        if chunk.choices and chunk.choices[0].delta.content is not None
    ]
    return {"answer": "".join(pieces), "isImage": False}
if __name__ == "__main__":
    # Start the server on all interfaces, port 8000, when run as a script.
    uvicorn.run(app, port=8000, host="0.0.0.0")
HTML模板
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>聊天页面</title>
<style>
body {
font-family: Arial, sans-serif;
background-color: #f4f4f4;
margin: 0;
padding: 0;
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
}
.chat-container {
width: 100%;
max-width: 600px;
background-color: #fff;
border-radius: 8px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
overflow: hidden;
}
.chat-header {
background-color: #007bff;
color: #fff;
padding: 10px;
text-align: center;
font-size: 18px;
}
.chat-body {
padding: 20px;
height: 400px;
overflow-y: auto;
border-bottom: 1px solid #ddd;
}
.chat-message {
margin-bottom: 10px;
}
.chat-message.user {
text-align: right;
color: #007bff;
}
.chat-message.bot {
text-align: left;
color: #333;
}
.chat-input {
display: flex;
padding: 10px;
}
.chat-input input {
flex: 1;
padding: 10px;
border: 1px solid #ddd;
border-radius: 4px;
margin-right: 10px;
}
.chat-input button {
padding: 10px 20px;
background-color: #007bff;
color: #fff;
border: none;
border-radius: 4px;
cursor: pointer;
}
.chat-input button:hover {
background-color: #0056b3;
}
</style>
</head>
<body>
<div class="chat-container">
<div class="chat-header">聊天页面</div>
<div class="chat-body" id="chat-body">
<!-- 聊天记录将在这里显示 -->
</div>
<div class="chat-input">
<input type="text" id="user-input" placeholder="输入问题...">
<button onclick="sendMessage()">发送</button>
</div>
</div>
<script>
// Send the typed question to the backend and render the reply in the chat log.
function sendMessage() {
    const inputBox = document.getElementById('user-input');
    const userInput = inputBox.value.trim();
    if (userInput === '') {
        alert('请输入问题');
        return;
    }

    const chatBody = document.getElementById('chat-body');

    // Append an element to the chat log and keep the newest entry in view.
    const appendToChat = (element) => {
        chatBody.appendChild(element);
        chatBody.scrollTop = chatBody.scrollHeight;
    };

    // Build a text bubble; textContent keeps the text from being parsed as HTML.
    const createTextBubble = (role, text) => {
        const bubble = document.createElement('div');
        bubble.className = `chat-message ${role}`;
        bubble.textContent = text;
        return bubble;
    };

    // Show the user's message, then clear the input box.
    appendToChat(createTextBubble('user', userInput));
    inputBox.value = '';

    // The backend in chat.py listens on port 8000, so the request targets it.
    fetch(`http://localhost:8000/get_answer?question=${encodeURIComponent(userInput)}`)
        .then(response => {
            // Treat HTTP error statuses as failures instead of parsing them.
            if (!response.ok) {
                throw new Error(`HTTP ${response.status}`);
            }
            return response.json();
        })
        .then(data => {
            if (data.isImage) {
                const botImage = document.createElement('img');
                botImage.src = data.answer;
                botImage.className = 'chat-message bot';
                botImage.width = 300;
                botImage.height = 200;
                appendToChat(botImage);
            } else {
                appendToChat(createTextBubble('bot', data.answer));
            }
        })
        .catch(error => {
            console.error('Error:', error);
            appendToChat(createTextBubble('bot', '无法获取回答,请稍后再试。'));
        });
}
</script>
</body>
</html>
运行应用
python chat.py
访问页面:后端启动后,在浏览器中打开 chat.html 即可开始对话。
总结
本 demo 使用了模型 meta/llama3-70b-instruct 和 nvidia/consistory,根据用户输入判断是否有作画意图,从而决定是否调用文生图模型。