1. 数据集下载
数据集下载链接：https://hyper.ai/datasets/33096
2. 数据集格式转换
需要将 COCO JSON 文件（_annotations.coco.json）中的标注信息转换为 YOLO 格式的标注文件（每张图片对应一个 .txt 文件）。
import json
import os
import shutil
import cv2
import matplotlib.pyplot as plt
# Dataset split to convert. Re-run the script with "./data/train" and
# "./data/test" as well so every split gets its own images/ and labels/ folders.
target = "./data/val"
def convert(size, box):
    """Convert a pixel-space corner box into a normalized YOLO box.

    Args:
        size: image size as ``(height, width, ...)`` (for example an OpenCV
            ``img.shape``); only the first two entries are used.
        box: ``(x1, y1, x2, y2)`` corner coordinates in pixels.

    Returns:
        ``(x_center, y_center, width, height)``, each divided by the image
        width or height. The corners are clipped to the image first, so every
        returned value lies in ``[0, 1]`` as YOLO label files require.
    """
    dh, dw = size[0], size[1]
    # Clip the corners to the image instead of capping only w/h at 0.99: the
    # cap left the centre unbounded, so a box crossing the image border still
    # described area outside the image, and full-size boxes shrank to 0.99.
    x1 = min(max(box[0], 0), dw)
    y1 = min(max(box[1], 0), dh)
    x2 = min(max(box[2], 0), dw)
    y2 = min(max(box[3], 0), dh)
    x = (x1 + x2) / 2.0 / dw
    y = (y1 + y2) / 2.0 / dh
    w = (x2 - x1) / dw
    h = (y2 - y1) / dh
    return (x, y, w, h)
# Convert this split's COCO annotations to YOLO format: copy every image into
# <target>/images/ and write one <image name>.txt into <target>/labels/ holding
# one "class x_center y_center width height" line per bounding box.
with open(os.path.join(target, "_annotations.coco.json"), encoding="utf-8") as f:
    anno = json.load(f)

# image id -> image record (file_name and, in COCO exports, width/height)
images = {img["id"]: img for img in anno["images"]}
# image id -> list of its annotations. An image may contain several boxes;
# assigning one annotation per id would keep only the last box of each image.
labels = {}
for an in anno["annotations"]:
    labels.setdefault(an["image_id"], []).append(an)
# Print a summary instead of dumping the whole annotation file.
print(f"{target}: {len(images)} images, {len(anno['annotations'])} annotations")

img_dir = os.path.join(target, "images")
anno_dir = os.path.join(target, "labels")
# Create each folder independently so a partially prepared split still works.
os.makedirs(img_dir, exist_ok=True)
os.makedirs(anno_dir, exist_ok=True)

for image_id, info in images.items():
    file_name = info["file_name"]
    # Copy the image into the images/ folder next to labels/.
    dst = os.path.join(img_dir, file_name)
    shutil.copyfile(os.path.join(target, file_name), dst)

    # COCO boxes are expressed in the image record's width/height; only fall
    # back to decoding the image when the record does not carry a size.
    if "width" in info and "height" in info:
        size = (info["height"], info["width"])
    else:
        img = cv2.imread(dst)
        if img is None:
            raise ValueError(f"cannot read image {dst}")
        size = img.shape

    # Images without annotations get an empty label file (a background
    # image) instead of raising KeyError.
    lines = []
    for label in labels.get(image_id, []):
        x, y, w, h = label["bbox"]  # COCO bbox: [x_min, y_min, width, height]
        box = convert(size, (x, y, x + w, y + h))
        values = [label["category_id"], *box]
        lines.append(" ".join(str(v) for v in values))

    stem, _ = os.path.splitext(file_name)
    with open(os.path.join(anno_dir, stem + ".txt"), "w", encoding="utf-8") as f:
        # One line per box, newline-terminated so multiple boxes stay separate.
        f.writelines(line + "\n" for line in lines)
将 test、train 和 val 都转换一下：把 target 依次改为 ./data/train、./data/val 和 ./data/test，分别运行上面的脚本。
3. 模型训练
数据配置文件（保存为 smoke.yaml，训练代码中通过 data="smoke.yaml" 引用）：
# Root directory of the dataset (adjust to where the data was extracted).
path: C:\Users\lhq\Desktop\Wildfire-Smoke\data
# Point each split at its images/ folder. Ultralytics finds each label file by
# replacing "images" with "labels" in the image path. Pointing at the split
# root instead would also glob the original image copies left there by the
# conversion script, and those (having no label file) would be treated as
# unlabeled background images.
train: "./train/images/"
val: "./val/images/"
test: "./test/images/"
# Class ids are written straight from the COCO category_id, so nc must cover
# every id used. Presumably id 0 is an unused parent category from the export;
# confirm against "categories" in _annotations.coco.json.
nc: 2
names:
  0: 烟雾
  1: 烟雾
训练代码：
from ultralytics import YOLO
from ultralytics.utils import DEFAULT_CFG
from datetime import datetime
if __name__ == "__main__":
    # Name each run after its start time so repeated trainings never overwrite
    # each other; results are saved under ./models/<YYYY-mm-dd_HH-MM-SS>/.
    time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    # Start from the pretrained YOLOv8n checkpoint.
    model = YOLO("yolov8n.pt")
    # Train the model. `project`/`name` are the documented way to choose the
    # output directory, rather than mutating the global DEFAULT_CFG at import
    # time (which also re-runs in every spawned worker process).
    results = model.train(
        data="smoke.yaml",
        epochs=100,
        imgsz=640,
        device=0,
        save=True,
        project="./models",
        name=time_str,
    )
4. 模型测试
预测代码：
from ultralytics import YOLO
# Load the trained detector weights from best.pt (presumably copied here from
# the training run's weights/ folder; adjust the path as needed).
model = YOLO('best.pt')
# Run inference on the contents of the ./demo/ folder at 640 px on device 0
# and save the annotated results.
predict_kwargs = dict(imgsz=640, save=True, device=0, plots=True)
model.predict("./demo/", **predict_kwargs)