NSFW 图片分类

news2024/11/23 1:47:51

NSFW指的是 不适宜工作场所(“Not Safe (or Suitable) For Work;”)。在本文中,将介绍如何创建一个检测NSFW图像的图像分类模型。

数据集

由于数据集的性质,我们无法从一些数据集的网站(如Kaggle等)获得所有图像。

但是我们找到了一个专门抓取这种类型图片的github库,所以我们可以直接使用。clone项目后可以运行下面的代码来创建文件夹,并将每个图像下载到其特定的文件夹中。

 folders = ['drawings','hentai','neutral','porn','sexy']
 urls = ['urls_drawings.txt','urls_hentai.txt','urls_neutral.txt','urls_porn.txt','urls_sexy.txt']
 names = ['d','h','n','p','s']
 
 for i,j,k in zip(folders,urls,names):
     try:
         #Specify the path of the  folder that has to be made
         folder_path = os.path.join('your directory',i)
         os.mkdir(folder_path)
     except:
         pass
     #setup the path of url text file
     url_path = os.path.join('Datasets_Urls',j)
     my_file = open(url_path, "r")
     data = my_file.read()
     #create a list with all urls
     data_into_list = data.split("\n")
     my_file.close()
     icount = 0
     for ii in data_into_list:
         try:
             #create a unique image names for each images
             image_name = 'image'+str(icount)+str(k)+'.png'
             image_path = os.path.join(folder_path,image_name)
             #download it using the library
             urllib.request.urlretrieve(ii, image_path)
             icount+=1
         except Exception as e:
             pass
         #this below code is done to make the count of the image same for all the data 
         #you can use a big number if you are building a more complex model or if you have a good system
         if icount == 2000:
             break

这里的folder变量表示类的名称,urls变量用于获取URL文本文件(可以根据文本文件名更改它),name变量用于为每个图像创建唯一的名称。

上面代码将为每个类下载2000张图像,可以编辑最后一个“if”条件来更改下载图像的个数。

数据准备

我们下载的文件夹可能包含其他类型的文件,所以首先必须删除不需要的类型的文件。

 image_exts = ['jpeg','.jpg','bmp','png']
 path_list = ['drawings','hentai','neutral','porn','sexy']
 cwd = os.getcwd()
 def remove_other_images(path_list):
     for ii in path_list:
         data_dir = os.path.join(cwd,'DataSet',ii)
         for image in os.listdir(os.path.join(data_dir)):
             image_path = os.path.join(data_dir,image_class,image)
             try:
                 img = cv2.imread(image_path)
                 tip = imghdr.what(image_path)
                 if tip not in image_exts:
                     print('Image not in ext list {}'.format(image_path))
                     os.remove(image_path)
             except Exception as e:
                 print("Issue with image {}".format(image_path))
 remove_other_images(path_list)

上面的代码删除了扩展名不是指定格式的图像。

另外图像可能包含许多重复的图像,所以我们必须从每个文件夹中删除重复的图像。

 cwd = os.getcwd()
 path_list = ['drawings','hentai','neutral','porn','sexy']
 def remove_dup_images(path_list):
     for ii in path_list:
         os.chdir(os.path.join(cwd,'DataSet',ii))
         filelist = os.listdir()
         duplicates = []
         hash_keys = dict()
         for index, filename in enumerate(filelist):
             if os.path.isfile(filename):
                 with open(filename,'rb') as f:
                     filehash = hashlib.md5(f.read()).hexdigest()
                 if filehash not in hash_keys:
                     hash_keys[filehash] = index
                 else:
                     duplicates.append((index,hash_keys[filehash]))
             
         for index in duplicates:
             os.remove(filelist[index[0]])
             print('{} duplicates removed from {}'.format(len(duplicates),ii))
 remove_dup_images(path_list)

这里我们使用hashlib.md5编码来查找每个类中的重复图像。

Md5为每个图像创建一个唯一的哈希值,如果哈希值重复(重复图像),那么我们将重复图片添加到一个列表中,稍后进行删除。

因为使用TensorFlow框架所以需要判断是否被TensorFlow支持,所以我们这里加一个判断:

 import tensorflow as tf
 
 os.chdir('{data-set} directory')
 cwd = os.getcwd()
 
 for ii in path_list:
     os.chdir(os.path.join(cwd,ii))
     filelist = os.listdir()
     for image_file in filelist:
         with open(image_file, 'rb') as f:
             image_data = f.read()
 
         # Check the file format
         _, ext = os.path.splitext(image_file)
         if ext.lower() not in ['.jpg', '.jpeg', '.png', '.gif', '.bmp']:
             print('Unsupported image format:', ext)
             os.remove(os.path.join(cwd,ii,image_file))            
         else:
             # Decode the image
             try:
                 image = tf.image.decode_image(image_data)
             except:
                 print(image_file)
                 print("unspported")
                 os.remove(os.path.join(cwd,ii,image_file))

以上就是数据准备的所有工作,在清理完数据后,我们可以拆分数据。比如分割创建一个训练、验证和测试文件夹,并手动添加文件夹中的图像,我们将80%用于训练,10%用于验证,10%用于测试。

模型

首先导入tensorflow

 import tensorflow as tf
 import os
 import numpy as np
 import matplotlib.pyplot as plt
 from sklearn.utils import shuffle
 import hashlib
 from imageio import imread
 import numpy as np
 from tensorflow.keras.preprocessing.image import ImageDataGenerator
 from tensorflow.keras.applications.vgg16 import VGG16
 from tensorflow.keras.applications.vgg16 import preprocess_input
 from tensorflow.keras.layers import Flatten,Dense,Input
 from tensorflow.keras.models import Model,Sequential
 from keras import optimizers

对于图像,默认大小设置为224,224。

 IMAGE_SIZE = [224,224]

可以使用ImageDataGenerator库,进行数据增强。数据增强也叫数据扩充,是为了增加数据集的大小。ImageDataGenerator根据给定的参数创建新图像,并将其用于训练(注意:当使用ImageDataGenerator时,原始数据将不用于训练)。

 train_datagen = ImageDataGenerator(
         rescale=1./255,
         preprocessing_function=preprocess_input,
         rotation_range=40,
         width_shift_range=0.2,
         height_shift_range=0.2,
         shear_range=0.2,
         zoom_range=0.2,
         horizontal_flip=True,
         fill_mode='nearest')

对于测试集也是这样:

 test_datagen = ImageDataGenerator(rescale=1./255)

为了演示,我们直接使用VGG模型

vgg = VGG16(input_shape=IMAGE_SIZE+[3],weights='imagenet',include_top=False

然后冻结前面的层:

for layer in vgg.layers:
    layer.trainable = False

最后我们加入自己的分类头:

x = Flatten()(vgg.output)
prediction = Dense(5,activation='softmax')(x)
model = Model(inputs=vgg.input, outputs=prediction)
model.summary()

模型是这样的:

训练

看看我们训练集:

train_set = train_datagen.flow_from_directory('DataSet/train',
                                              target_size=(224,224),
                                              batch_size=32,
                                              class_mode='sparse')

验证集

val_set = train_datagen.flow_from_directory('DataSet/validation',
                                              target_size=(224,224),
                                              batch_size=32,
                                              class_mode='sparse')

使用’ sparse_categorical_crossentropy '损失,这样可以将标签编码为整数而不是独热编码。

from tensorflow.keras.metrics import MeanSquaredError
from tensorflow.keras.metrics import CategoricalAccuracy
adam = optimizers.Adam()
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=adam,
              metrics=['accuracy',MeanSquaredError(name='val_loss'),CategoricalAccuracy(name='val_accuracy')])

然后就可以训练了:

from datetime import datetime
from keras.callbacks import ModelCheckpoint

log_dir = 'vg_log'

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir = log_dir)

start = datetime.now()

history = model.fit_generator(train_set,
                              validation_data=val_set,
                              epochs=100,
                              steps_per_epoch=len(train_set)// batch_size,
                              validation_steps=len(val_set)//batch_size,
                              callbacks=[tensorboard_callback],
                             verbose=1)

duration = datetime.now() - start
print("Time taken for training is ",duration)

模型训练了100次。得到了80%的验证准确率。f1得分为93%

预测

下面的函数将获取一个图像列表并根据该列表进行预测。

import numpy as np
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
def print_classes(images,model):
    classes = ['Drawing','Hentai','Neutral','Porn','Sexual']
    fig, ax = plt.subplots(ncols=len(images), figsize=(20,20))
    for idx,img in enumerate(images):
        img = mpimg.imread(img)
        resize = tf.image.resize(img,(224,224))
        result = model.predict(np.expand_dims(resize/255,0))
        result = np.argmax(result)
        if classes[result] == 'Porn':
            img = gaussian_filter(img, sigma=6)
        elif classes[result] == 'Sexual':
            img = gaussian_filter(img, sigma=6)
        elif classes[result] == 'Hentai':
            img = gaussian_filter(img, sigma=6)
        ax[idx].imshow(img)
        ax[idx].title.set_text(classes[result])

li = ['test1.jpeg','test2.jpeg','test3.jpeg','test4.jpeg','test5.jpeg']
print_classes(li,model)

看结果还是可以的。

最后,本文的源代码:

https://avoid.overfit.cn/post/8f681841d02e4a8db7bcf77926e123f1

作者:Nikhil Thalappalli

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/550132.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

少儿编程 中国电子学会图形化编程等级考试Scratch编程四级真题解析(选择题)2023年3月

2023年3月scratch编程等级考试四级真题 选择题(共25题,每题2分,共50分) 1、编写一段程序,从26个英文字母中,随机选出10个加入列表a。空白处应填入的代码是 A、 B、 C、 D、 答案:C

[CTF/网络安全] 攻防世界 simple_php 解题详析

[CTF/网络安全] 攻防世界 simple_php 解题详析 代码解读PHP弱语言特性姿势参数a限制绕过参数b限制绕过 总结 题目描述:小宁听说php是最好的语言,于是她简单学习之后写了几行php代码。 代码解读 $a$_GET[a]; 从HTTP GET请求参数中获取一个名为a的变量&#xff0c…

协同过滤算法的召回率、准确率、覆盖率、新颖度

python版计算协同过滤推荐算法的召回率、准确率、覆盖率、新颖度 推荐算法网站示例Demo 点我跳转图书管理推荐系统 点我跳转课程推荐系统 点我跳转电影推荐系统 1、召回率、准确率 2、覆盖率、新颖度 覆盖率反映了推荐算法发掘长尾的能力,覆盖率越高,说明推荐算法越能够将…

ChatGPT开始颠覆学习方式,应试教育面临哪些挑战?

ChatGPT爆火几个月,整个教育系统都在被颠覆。全球范围内,不少大学教授、系主任和管理人员,都在对课堂进行大规模的调整,以应对ChatGPT对教学活动造成的巨大冲击。 国内传统应试教育选出的分霸、考霸,是更能吃苦&#…

c++中的方法

c中的方法 static方法 与数据成员类似,方法有时会应用于全部对象而不是单个对象。可以编写static方法和数据成员。在方法声明前加上static即可。对于方法的定义前则不需要重复使用static关键字。 class Foo { public:static int sumFunc(int a, int b); };int Fo…

康耐视Visionpro工具-CogPMAlignTool为什么是最牛工具?

1.算法:有六种选项,分别是:PatMax,PatQuick, PatMax 与 PatQuick, PatFlex,PatMax-高灵敏度,透视 Patmax。 PatQuick 特点:速度最快,对于三维或者低质量原件最佳,承受更多图像差异; PatMax 特点:精确度最高,在二维元件上表现佳,最适合于细微细节; PatFlex 特点…

4. 通讯录实现的需求分析和架构设计

本文实现的是通讯录产品的需求分析和架构设计,重点在于结构层次的设计,方便代码阅读和维护。 一、通讯录实现的需求分析 1、通讯录的功能清单 添加一个人员打印显示所有人员删除一个人员查找一个人员保存文件加载文件 2,数据存储信息 人员…

[CTF/网络安全] 攻防世界 disabled_button 解题详析

[CTF/网络安全] 攻防世界 disabled_button 解题详析 input标签姿势disable属性总结 题目描述:X老师今天上课讲了前端知识,然后给了大家一个不能按的按钮,小宁惊奇地发现这个按钮按不下去,到底怎么才能按下去呢? input标…

Tiny+ 语言词法之C语言

访问【WRITE-BUG数字空间】_[内附完整源码和文档] 语义分析本质上就是在语法分析的基础上进一步完善分析的功能。举个例子来说,在语法分析部分的 if_stmt 函数中,在语义上判断条件必须返回布尔类型的值,因此我们加入一个判断,判断…

Unity之OpenXR+XR Interaction Toolkit示例Demo详解

一.前言 自从升级Unity版本到2021,然后使用OpenXR开发VR之后,我们整个团队的开发效率都提升了不少,这证明了不管什么领域,统一接口,统一规范都是必须的。 关于XR Interaction Toolkit插件,我已经写了几篇文章了,今天才想起来,最基础的Demo讲解还没有写,其实官方的这个…

chatgpt赋能Python-pythonfor循环5次

Python中for循环的使用方法及技巧 Python作为一种高级编程语言,其独特的语法结构和方便的操作方法受到了越来越多人的欢迎和喜爱。其中,for循环是Python程序员必备的基本技巧之一。在这篇文章中,我们将介绍Python中for循环的使用方法及技巧。…

HTTP协议【面试高频考点】

目录 一、HTTP 响应 1.首行 2.状态码(经典面试题,必考) 2.1 200 OK 2.2 404 Not Found 2.3 403 Forbidden 2.4 500 Internal Server Error 2.5 504 Gateway Timeout 2.6 302 Move temporarily 2.7 301 Moved Permanently 2.8 状态…

clearmymac4.13.5专业的Mac系统清理优化工具

CleanMyMac X是一款功能强大的Mac清理工具,它可以扫描您的Mac电脑,清除垃圾文件,卸载无用的应用程序,并优化系统性能。此外,CleanMyMac X还可以找到和修复Mac电脑上的许多其他问题,即使您不是技术专家也可以…

chatgpt赋能Python-pythona__a

Python中的aa 介绍 Python是一种流行的编程语言,具有简单易学和可读性强的特点。在Python中,常常使用aa这样的表达式,它表示将变量a的原始值加上它自己的值,然后将结果赋值给变量a。这种语法看起来很简单,但实际上有…

C语言函数大全-- _w 开头的函数(5)

C语言函数大全 本篇介绍C语言函数大全-- _w 开头的函数 1. _wspawnl 1.1 函数说明 函数声明函数功能int _wspawnl(int mode, const wchar_t* cmdname, const wchar_t* arglist, ...);启动一个新的进程并运行指定的可执行文件 参数: mode : 启动命令的…

【008】C++数据类型之重要关键字详解

C数据类型之重要关键字详解 引言一、const修饰普通变量重点说明 二、register修饰寄存器变量三、volatile强制访问内存四、sizeof测试类型的大小五、typedef关键字总结 引言 💡 作者简介:专注于C/C高性能程序设计和开发,理论与代码实践结合&a…

搭建go web 框架

思想base部分day1:封装gee封装context上下文封装前缀tree路由树分组封装group与中间件封装文件解析封装封装错误处理测试 思想 web框架服务主要围绕着请求与响应来展开的 搭建一个web框架的核心思想 1 便捷添加响应路径与响应函数(base) 2 能够接收多种数据类型传入(上下文cont…

第二章 表操作

一、数据表的设计理念 数据表是包括数据库所有数据的数据库对象,数据在表中的组织方式与在电子表格中相似,都是按行和列的格式组织的,其中每一行代表一条唯一的记录,每一列代表记录中的字段,表中的数据库对象包含列、…

Godot4节点树右键菜单添加自定义选项

前言 查看godot的源码推荐使用在线版vscode直接从github上看。(直接把网址的com改成dev即可) 重点查看以下源码 scene_tree_dock.h scene_tree_dock.cpp 开始 tool extends EditorPluginvar window var scene_menustatic func find_child_by_class(no…

OneDrive同步角标消失 - 解决方案

问题 在电脑端使用OneDrive时,文件管理器OneDrive文件夹内的文件会在左下角显示同步状态,如下图。若没有显示同步角标,则此功能出现异常,下文介绍如何显示同步角标。 值得一提的是,同步角标只起到显示作用&#xff0…