在亚马逊云科技上对Stable Diffusion模型提示词、输出图像内容进行安全审核

news2024/9/23 11:13:47

项目简介:

小李哥将继续每天介绍一个基于亚马逊云科技AWS云计算平台的全球前沿AI技术解决方案,帮助大家快速了解国际上最热门的云计算平台亚马逊云科技AWS AI最佳实践,并应用到自己的日常工作里。

本次介绍的是如何在亚马逊云科技机器学习托管服务SageMaker上部署开源大模型Stable Diffusion,利用亚马逊云科技Comprehend对模型输入提示词进行有害性检测,并利用亚马逊云科技Rekognition服务对生成图像内容进行有害性检测,构建负责任的AI防止大模型被滥用。本架构设计全部采用了云原生Serverless架构,提供可扩展和安全的AI解决方案。本方案的解决方案架构图如下:

方案所需基础知识   

什么是 Amazon SageMaker?

Amazon SageMaker 是亚马逊云科技提供的一站式机器学习服务,帮助开发者和数据科学家轻松构建、训练和部署机器学习模型。SageMaker 提供了全面的工具,从数据准备、模型训练到部署和监控,覆盖了机器学习项目的全生命周期。通过 SageMaker,用户可以加速机器学习模型的开发和上线,并确保模型在生产环境中的稳定性和性能。

什么是 Amazon Comprehend?

Amazon Comprehend 是亚马逊云科技提供的一项自然语言处理(NLP)服务,能够自动从文本中提取有价值的信息。通过机器学习技术,Comprehend 可以识别文本中的实体、情感、关键词、语言、主题等,帮助企业更好地理解和分析大量非结构化数据。它适用于客户反馈分析、内容分类、文档处理等场景,使得信息挖掘和数据洞察变得更加简单和高效。

什么是 Amazon Rekognition?

Amazon Rekognition 是亚马逊云科技提供的一项图像和视频分析服务。它使用深度学习技术来检测、识别和分析图像中的对象、场景、面部表情、文字等。Rekognition 可以应用于多种场景,如面部识别、内容审核、对象检测和人群统计等,帮助企业自动化处理图像和视频数据,提升效率并增强安全性。

什么是 Stable Diffusion?

Stable Diffusion 是一种先进的生成式 AI 模型,专门用于生成高质量的图像。通过扩散模型技术,Stable Diffusion 能够将简单的文本描述转化为逼真的图像。这个模型具有强大的生成能力,可以应用于艺术创作、广告设计、游戏开发等领域,为用户提供丰富的视觉内容生成工具。

为什么要对 Stable Diffusion 输入输出内容进行安全审核?

防止不当内容生成

Stable Diffusion 可以根据输入的文本生成图像,但如果输入的文本内容不当或恶意,可能会生成带有敏感、违法或不道德内容的图像。对输入输出内容进行审核,能够有效防止此类内容的生成和传播,确保模型的使用符合道德和法律标准。

保护用户隐私

在生成图像时,可能涉及到用户的私人信息或敏感数据。通过审核输入输出内容,可以确保这些信息不会被意外泄露或滥用,保护用户的隐私权。

遵守法律法规

各国对生成和传播图像内容有不同的法律规定。通过对内容进行审核,企业可以确保生成的图像符合所在国家或地区的法律法规,避免法律风险。

维护品牌声誉

对内容进行安全审核,有助于防止不符合公司价值观或可能损害品牌声誉的内容生成,从而维护品牌的形象和公众信任。

本方案包括的内容

1. 在SageMaker上部署开源大模型Stable Diffusion

2. 在SageMaker上调用Stable Diffusion模型API生成图片

3. 将Stable Diffusion模型API节点集成到云端应用上

4. 评估大模型输入问题的有害性

5. 对大模型输出图片进行安全审核

项目搭建具体步骤:

1. 打开亚马逊云科技控制台,进入Amazon SageMaker服务主页,点击Open Studio进入模型开发环境。

2. 创建一个新的Jupyte NoteBook文件,复制以下代码安装必要依赖并指明Stable Diffusion模型ID。

%pip install --upgrade sagemaker --quiet
model_id = "model-imagegeneration-stabilityai-stable-diffusion-xl-base-1-0"

3. 运行以下代码列举出JumpStart中,可以快速部署的用于生成图片的所有Stable Diffusion大模型

import IPython
from ipywidgets import Dropdown
from sagemaker.jumpstart.notebook_utils import list_jumpstart_models
from sagemaker.jumpstart.filters import And


filter_value = And("task == imagegeneration")
ss_models = list_jumpstart_models(filter=filter_value)

dropdown = Dropdown(
    value=model_id,
    options=ss_models,
    description="Sagemaker Pre-Trained Image Generation Models:",
    style={"description_width": "initial"},
    layout={"width": "max-content"},
)
display(IPython.display.Markdown("## Select a pre-trained model from the dropdown menu"))
display(dropdown)

4. 运行以下代码开始部署Stable Diffusion大模型。

# Deploy the model
from sagemaker.jumpstart.model import JumpStartModel
from sagemaker.serializers import JSONSerializer
import time

# The model is deployed on an ml.g5.4xlarge instance. To see all the supported parameters by the JumpStartModel
# class use this link - https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#sagemaker.jumpstart.model.JumpStartModel
my_model = JumpStartModel(model_id=dropdown.value)
predictor = my_model.deploy()
# Wait for a few seconds so model the is properly loaded.
time.sleep(60)

5. 运行以下代码,导入调用大模型的必要依赖,配置图片生成请求参数,这里我们的图片生成提示词为”生成一个亚马逊雨林中的美洲虎图片“。同时我们定一个图片解码函数”decode_and_show“用于显示生成的图片,最后调用图片生成API "Predictor.predict()"生成图片。

from PIL import Image
import io
import base64
import json
import boto3
from typing import Union, Tuple
import os

payload = {
    "text_prompts": [{"text": "jaguar in the Amazon rainforest"}],
    "width": 1024,
    "height": 1024,
    "sampler": "DPMPP2MSampler",
    "cfg_scale": 7.0,
    "steps": 50,
    "seed": 133,
    "use_refiner": True,
    "refiner_steps": 40,
    "refiner_strength": 0.2,
}

def decode_and_show(model_response) -> None:
    """
    Decodes and displays an image from SDXL output

    Args:
        model_response (GenerationResponse): The response object from the deployed SDXL model.

    Returns:
        None
    """
    image = Image.open(io.BytesIO(base64.b64decode(model_response)))
    display(image)
    image.close()


response = predictor.predict(payload)
# If you get a time out error, check the endpoint logs in Amazon CloudWatch for the model loading status
# and invoke it again.
decode_and_show(response["generated_image"])

我们在开发环境里可以看到大模型生成的图片内容。

 6. 接下来我们进入到无服务器计算服务Lambda中,创建一个函数”check_toxicity_function“,用于调用Amazon Comprehend服务的API,模型检测输入文字的有害性并返回到客户端。我们复制以下代码到Lambda函数中

import json
import boto3
import os

comprehend = boto3.client('comprehend')
THRESHOLD = float(os.environ['THRESHOLD'])

def check_toxicity(text_prompts):
    detected_labels = []
    for prompt in text_prompts:
        response = comprehend.detect_toxic_content(
            TextSegments=[
                {
                    "Text": prompt['text']
                }
            ],
            LanguageCode='en'
        )
        labels = response['ResultList'][0]['Labels']
        # DIY section
        # Replace l['Name'] with {l['Name']:l['Score']} so that detected
        # is an array of json objects
        detected = [l['Name']for l in labels if l['Score'] > THRESHOLD]
        if detected:
            detected_labels.extend(detected)
    return detected_labels

def lambda_handler(event, context):
    print("event is ", json.dumps(event))
    try:
        text_prompts = [json.loads(event['body'].strip('"'))]
        detected_labels = check_toxicity(text_prompts)
        if detected_labels:
            return {
                'statusCode': 200,
                'headers': {
                    'Content-Type': 'application/json',
                    'Access-Control-Allow-Headers': 'Content-Type',
                    'Access-Control-Allow-Origin': '*',
                    'Access-Control-Allow-Methods': 'OPTIONS,POST'
                },
                'body': json.dumps({'detected_labels': detected_labels})
            }
        else:
            return {
                'statusCode': 200,
                'headers': {
                    'Content-Type': 'application/json',
                    'Access-Control-Allow-Headers': 'Content-Type',
                    'Access-Control-Allow-Origin': '*',
                    'Access-Control-Allow-Methods': 'OPTIONS,POST'
                },
                'body': json.dumps({'detected_labels': 'non-toxic content and safe to proceed'})
            }
    except Exception as e:
        print(f"Error: {e}")
        return {
            'statusCode': 500,
            'headers': {
                'Content-Type': 'application/json',
                'Access-Control-Allow-Headers': 'Content-Type',
                'Access-Control-Allow-Origin': '*',
                'Access-Control-Allow-Methods': 'OPTIONS,POST'
            },
            'body': json.dumps({'error': 'An error occurred while processing the request'})
        }

7. 我们再建一个新的Lambda函数”classifier_lambda_function“,调用Amazon Rekognition服务API对Stable Diffusion生成的图片进行内容审核。复制以下代码到Lambda中。

import io
import base64
import json
import boto3
import os
import uuid
import ast


comprehend = boto3.client('comprehend')
sagemaker_runtime = boto3.client("runtime.sagemaker")
rekognition = boto3.client('rekognition')
s3_client = boto3.client('s3')
s3 = boto3.resource('s3')

ENDPOINT_NAME = os.environ["ENDPOINT_NAME"]
bucket_name = os.environ['BUCKET_NAME']
THRESHOLD = 0.2
s3_folder = 'generated_images/'



def query_endpoint(prompt):
    response = sagemaker_runtime.invoke_endpoint(
        EndpointName=ENDPOINT_NAME, ContentType="application/json", Body=json.dumps(prompt,separators=(',', ':')).encode("utf-8")
    )
    print("response is ",response)
    result = json.loads(response["Body"].read().decode())
    return result
    
    
def detect_moderation(img_bytes):
    confidence_data = [ ]
    response = rekognition.detect_moderation_labels(
        Image={
        'Bytes': base64.b64decode(img_bytes)
        })
    for label in response['ModerationLabels']:
        confidence = label['Name'] + ' : ' + str(label['Confidence'])
        print (label['Name'] + ' : ' + str(label['Confidence']))
        print("confidence is ", confidence)
        confidence_data.append(confidence + "\n")
    
    return confidence_data

def lambda_handler(event,context):
    print("event is ",json.dumps(event))
    pm_str=json.loads(event["body"].strip('"'))
    prompt = {
        "text_prompts":  [(pm_str)],
        }
    print(prompt)
    response = query_endpoint(prompt)
    if "generated_image" in response:
        image_data = response["generated_image"]
        confLevel = detect_moderation(image_data)
        print(confLevel, len(confLevel))
        
        if len(confLevel) > 0:
            return {
                    'statusCode': 400,
                    'headers': {
                        'Access-Control-Allow-Headers': 'Content-Type',
                        'Access-Control-Allow-Origin': '*',
                        'Access-Control-Allow-Methods': 'OPTIONS,POST'
                    },
                    'body': json.dumps(confLevel)
                }
        else:
            imageBytes = io.BytesIO(base64.b64decode(image_data))
            file_name = f'generated-image-{uuid.uuid4()}.jpg'
    
            s3_client.upload_fileobj(
                imageBytes,
                bucket_name,
                f'{s3_folder}{file_name}',
                ExtraArgs={'ContentType': 'image/jpeg'}
            )
            return {
                    'statusCode': 200,
                    'headers': {
                        'Content-Type': 'image/png',
                        'Access-Control-Allow-Headers': 'Content-Type',
                        'Access-Control-Allow-Origin': '*',
                        'Access-Control-Allow-Methods': 'OPTIONS,POST'
                    },
                    'body': json.dumps(file_name),
                    'isBase64Encoded': True
            }
    else:
        return {
            'statusCode': 400,
            'headers': {
                'Access-Control-Allow-Headers': 'Content-Type',
                'Access-Control-Allow-Origin': '*',
                'Access-Control-Allow-Methods': 'OPTIONS,POST'
            },
            'body': json.dumps({'error': 'Response is not in the expected format'})
    }


       


8. 接下来我们为Lambda函数前面添加一个API Gateway,作为API管理服务并提供对外暴露的API端点,在该服务中我们定义不同的HTTP方法、路径,绑定不同的Lambda函数来管理API。

如使用POST方法调用路径/classifier时,我们触发Lambda函数:”classifier_lambda_function“。使用POST方法调用路径/classifier/checkToxicity时,我们出发函数:”check_toxicity_function“。

同时API Gateway服务提供了端点URL供用户访问。

9. 本架构中我们使用到了CloudFront对API和网页请求进行加速,我们进入CloudFront服务页面中,复制并打开URL。

10. 首先我们对提示词文字进行检测,我们输入问题得到了回复”提示词包含侮辱性词汇“。

11. 我们再在相同界面中输入”生成一个晴朗的一天“,该提示词通过了文字有害性检测,生成的图片也通过安全检查,成功显示在生成界面中。

以上就是在亚马逊云科技上利用亚马逊云科技上利用Amazon Sagemaker部署Stable Diffusion模型,并对输入提示词和输出图像内容进行安全审核,的全部步骤。欢迎大家未来与我一起,未来获取更多国际前沿的生成式AI开发方案。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2050714.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HighPoint SSD7749M2:128TB NVMe 存储卡实现28 GB/s高速传输

HighPoint Technologies推出了一款全新的SSD7749M2 RAID卡,能够在标准的桌面工作站中安装多达16个M.2 SSD,实现高达128TB的闪存存储。该卡通过PCIe Gen4 x16接口提供高达28 GB/s的顺序读写性能。这些令人瞩目的性能规格伴随着高昂的价格标签。 #### 技术…

ArcGIS Pro基础:设置快速访问工具栏

上图【红色框线】内显示就是快速访问工具栏,访问非常方便,不需要切换到选项卡了 上图显示,可以勾选或者取消进行设置,通过【更多命令】可以选择更多的工具 如上图所示,可以选择自己经常使用的命令,可以输入…

手撕线程池

1.手撕线程池原理图 2.代码实现 // 手撕线程池 public class Main {public static void main(String[] args) {ThreadPool threadPool new ThreadPool(1,1000,TimeUnit.MILLISECONDS,1,(queue, task) -> {queue.putByTime(task,1500,TimeUnit.MILLISECONDS);});for (int i…

LangChain 实战演练:借助 LangChain SQL Agent 与 GPT 实现文档智能分析及交互

LangChain实战:利用LangChain SQL Agent和GPT进行文档分析和交互 我最近接触到一个非常有趣的挑战,涉及到人工智能数字化大量文件的能力,并使用户可以在这些文件上提出复杂的与数据相关的问题,比如: 数据检索问题&…

【qt】基于tcp的消息发送

我们需要实现客户端发消息,服务端接收消息 服务端界面新增接收消息 实现客户端发送和清空 发送数据需要将发送栏的信息转化为QByteArray,然后使用socket的write发送过去 实现服务端的接收 效果演示 20240818_111603 代码展示 server Widget.h #ifndef WIDGET_H …

Java的File类与IO流

目录 1. java.io.File类的使用 1.1 概述 1.2 构造器 1.3 常用方法 1、获取文件和目录基本信息 2、列出目录的下一级 3、File类的重命名功能 4、判断功能的方法 5、创建、删除功能 1.4 练习 2. IO流原理及流的分类 2.1 Java IO原理 2.2 流的分类 2.3 流的API 3. …

如何在 Windows/Mac/在线/iPhone/Android 上将 PDF 转换为 Word

PDF(便携式文档格式)是一种流行的格式,广泛用于在数字电子设备中呈现文档。输出文件小且兼容性强,使 PDF 如此受欢迎。但是,编辑 PDF 文件并非免费。您无需购买 PDF 编辑器,而是可以将 PDF 转换为 Word 进行…

「OC」NSPredicate —— 使用谓词过滤元素

「OC」NSPredicate —— 使用谓词过滤元素 文章目录 「OC」NSPredicate —— 使用谓词过滤元素前言介绍常见用法**比较运算符****逻辑运算符****字符串比较运算符****聚合运算符****用于字典或者类当中****格式说明符(占位符)** 实际运用总结参考文章 前…

05创建型设计模式——原型模式

一、原型模式简介 原型模式(Prototype Pattern)模式是一种对象创建型模式,它采取复制原型对象的方法来创建对象的实例。使用原型模式创建的实例,具有与原型一样的数据。 1)由原型对象自身创建目标对象。换句话说&…

python基础语法 010 类和对象-3 方法

1.3 方法 属性表示是一个类当中的成员或类的特征,而方法是?? 方法:表示类、对象的行为,方法本质上是函数,是一个特殊的函数 属性名称一般为名词,方法名称一般为动词 1.3.1 方法 VS 属性 1、…

24/8/17算法笔记 DDPG算法

深度确定性策略梯度(DDPG)算法是一种用于解决连续动作空间强化学习问题的算法。它结合了确定性策略梯度(DPG)和深度学习技术的优点,通过Actor-Critic框架进行策略和价值函数的近似表示。DDPG算法的关键组成部分包括经验…

【RAG综述】北京大学检索增强技术综述

RAG for AIGC ​ 图 1 描述了一个典型的 RAG 过程。给定一个输入查询,检索器识别相关的数据源,检索到的信息与生成器交互以改进生成过程。根据检索结果如何增强生成,有几种基础范式(简称基础):它们可以作为…

STM32的蜂鸣器

蜂鸣器分为有源蜂鸣器和无源蜂鸣器。 有源蜂鸣器:内部有震荡源,只要通电即可自动发出固定频率的声音。(频率固定无 法控制音色) 。 无源蜂鸣器:内部无震荡源,需要外部脉冲信号驱动发声,声音频…

《机器学习》 线性回归 一元、多元 推导 No.3

一、什么是线性回归 线性回归是一种用于预测连续数值的机器学习算法。它基于输入特征与目标变量之间的线性关系建立了一个线性模型。线性回归的目标是找到最佳拟合直线,以最小化预测值与实际值之间的误差。这个线性模型可以用来进行预测和推断。 线性回归的模型可以…

SpringBoot Profile多环境配置及配置优先级

【SpringBoot学习笔记 三】Profile多环境配置及配置优先级_profiles队列中的优先值-CSDN博客 Profile激活方式 但是我们发现一个问题,就是每次切换环境还需要去配置里指定,然后通过修改dev为test或prod来切换项目环境 , 这样做的话每次切换环境都要重新改…

前端面试——如何判断对象和数组

给你一个值,如何判断其是对象还是数组??? 我们先给出数据 var lists [1,2,3,4,5]var objs {length:5 } 我们分别尝试如下五种方法 console.log((✘)使用length,lists.length,objs.length); console.log((✔)使用isArray,Arr…

【已成功EI检索】第三届机电一体化技术与航空航天工程国际学术会议(ICMTAE 2023)

重要信息 大会官网:www.icmtae.org 大会时间:2023年9月15-17日 大会地点:中国-江西南昌理工学院(南昌市青山湖区经济技术开发区英雄大道901号) 接受/拒稿通知:投稿后1周内 收录检索:EI 和 …

Vulkan 学习(4)---- Vulkan 逻辑设备

目录 Vulkan Logical Device OverView逻辑设备创建VkDeviceQueueCreateInfoDeviceExtension获取DeviceQueue参考代码 Vulkan Logical Device OverView 在 Vulkan 中,逻辑设备(Logical Device)是与物理设备(Physical Device)交互的接口,它抽象了对特定GPU(物理设备)…

CDD数据库文件制作(八)——服务配置(0x85)

目录 1.子功能创建2.会话切换配置/安全等级配置2.1.根据诊断调查表进行信息提取2.2.会话转换配置/安全等级配置3.寻址方式信息提取/禁止肯定响应位(SPRMIB)信息3.1.寻址方式/禁止肯定响应位(SPRMIB)配置4.否定响应码信息提取4.1.否定响应码配置按照诊断调查表中对0x85服务的…

PX30 Android8.1适配AIC8800 wifi

wifi驱动生成ko文件 生成后 通过wpa_supplicant加载参数 external/wpa_supplicant_8/wpa_supplicant/main.c int main(int argc, char *argv[]) {int ret -1;char module_type[20]{0};wpa_printf(MSG_INFO,"argc %d\n",argc);if(argc 2) {if (wifi_type[0] 0) …