【人工智能学习之商品检测实战】
- 1 开发过程
- 2 网络训练效果
- 2.1 分割网络
- 2.2 特征网络
- 3 跟踪与后处理
- 4 特征库优化
- 5 项目源码解析
- 5.1 yolo训练
- train_yolo.py
- good_net.py
- dataset.py
- good_cls_data.py
- save_feature.py
- analyse_good.py
- shop_window.py
- test.py
- 6 结语
1 开发过程
- 拍摄并处理数据集
- 训练YOLOV8侦测商品
- 训练特征网络识别商品
- 商品跟踪与后处理
2 网络训练效果
2.1 分割网络
yolov8-n 效果较差
yolov8-m 和 yolov8-x 耗时更长,但效果并没有非常突出
最终选择:yolov8-s-seg
2.2 特征网络
对比 ResNet50、ResNet101、MobileNet、DenseNet121 与 DenseNet169
DenseNet169 最为精准,在本数据集上提取特征能力和泛化性最强
最终选择:DenseNet169
最终余弦相似度测试:95.7%
3 跟踪与后处理
判断新旧物体进行跟踪,裁剪所需内容。
4 特征库优化
读取特征库时按类别进行字典嵌套
5 项目源码解析
5.1 yolo训练
train_yolo.py
作用:训练测试yolo模型
from ultralytics import YOLO
from PIL import Image
import cv2
# 版本有问题可以进行以下尝试
# import matplotlib
# matplotlib.use('TkAgg')
def yolo_seg_train():
    """Fine-tune the YOLOv8-s segmentation checkpoint on the goods dataset."""
    seg_model = YOLO(model=r"D:\zhrdpy_project\AutShop\model\yolov8s-seg.pt")
    seg_model.train(
        data="D:/zhrdpy_project/AutShop/data/YOLODataset/dataset.yaml",
        epochs=100,
        batch=-1,      # auto batch size
        workers=0,
        amp=False,     # mixed precision off
    )
def yolo_train():
    """Continue training the detector from a previous best checkpoint.

    Fix: the original first built a model from yolov8n.yaml and then
    immediately rebound ``model`` to the checkpoint below, so the yaml model
    was constructed and discarded; that dead construction is removed.
    """
    model = YOLO("D:/zhrdpy_project/AutShop/runs/detect/train7/weights/best.pt")
    model.train(data="D:/zhrdpy_project/AutShop/data/YOLODataset/dataset.yaml", epochs=10, batch=-1, workers=0)
    # metrics = model.val()
def yolo_val():
    """Validate the trained segmentation weights on the dataset's val split."""
    # model = YOLO("D:/zhrdpy_project/AutShop/model/yolov8n.yaml")
    model = YOLO(task='segment',model="D:/zhrdpy_project/AutShop/runs/segment/train_x/weights/best.pt")
    metrics = model.val(
        data='D:/zhrdpy_project/AutShop/data/YOLODataset/dataset.yaml',
        # other tunables: imgsz, batch, conf, iou, max_det, half, device, dnn
        save_json=True,  # save results to a JSON file (default False)
        save_hybrid=True,  # save hybrid labels (labels + extra predictions) (default False)
        # plots=True,  # show plots during training (default False)
        rect=True,  # rectangular val: batches collated for minimal padding (default False)
    )  # no arguments needed, dataset and settings remembered
def yolo_test():
    """Run the trained segmentation weights over a test video and save the rendering."""
    # use your own trained weights, e.g. "ultralytics-main1/runs/detect/train5/weights/best.pt"
    model = YOLO("D:/zhrdpy_project/AutShop/runs/segment/train_x/weights/best.pt")
    # accepts all formats - image/dir/Path/URL/video/PIL/ndarray. 0 for webcam
    # source "0" would be the webcam; not used now, so kept commented out
    # results = model.predict(source="0")
    # source = absolute path of your own validation media; save=True writes the output
    results = model.predict(source=r"D:\zhrdpy_project\AutShop\data\test.mp4", show=False, save=True)
    # success = model.export(format="onnx")
if __name__ == '__main__':
    # Uncomment the stage you want to run; only validation runs by default.
    # yolo_train()
    # yolo_seg_train()
    # yolo_test()
    yolo_val()
good_net.py
作用:特征网络
import torchvision.models as models
from torch import nn
import torch
from torch.nn import functional as F
from good_cls_data import one_hot_size
import math
class Arcsoftmax(nn.Module):
    """ArcFace-style margin softmax head.

    Holds a (feature_num, cls_num) weight matrix; ``forward`` returns the log
    of a softmax over class cosines where the angle of every class gets the
    additive margin ``m`` applied (see formula in ``forward``).
    """
    def __init__(self, feature_num, cls_num):
        super().__init__()
        self.w = nn.Parameter(torch.randn((feature_num, cls_num)))
        # Kaiming-uniform re-init of w: better starting point than raw randn
        nn.init.kaiming_uniform_(self.w, a=math.sqrt(5))

    def forward(self, feature, m=0.5, s=1):
        """Return log-probabilities with angular margin m and scale s.

        feature: (batch, feature_num) embedding batch.
        """
        # L2-normalize features per row and weights per column so the matmul
        # below yields cosine similarities.
        x = F.normalize(feature, dim=1)
        w = F.normalize(self.w, dim=0)
        # /10 shrinks |cos| well below 1 so acos stays away from its
        # infinite-gradient endpoints ("prevent gradient explosion" hack).
        cos = torch.matmul(x, w) / 10
        a = torch.acos(cos)
        # Numerator: every class angle gets margin m here (not only the target
        # class, unlike canonical ArcFace); denominator swaps the plain term
        # for the margined one per class.
        top = torch.exp(s * torch.cos(a + m))
        down = torch.sum(torch.exp(s * torch.cos(a)), dim=1, keepdim=True) - torch.exp(s * torch.cos(a)) + top
        out = torch.log(top/down)
        return out
'''
# 复杂但似乎没什么暖用的优化
def forward(self, x, s=1, m=0.5):
x_norm = F.normalize(x, dim=1)
w_norm = F.normalize(self.w, dim=0)
cos_theta = torch.matmul(x_norm, w_norm)
theta = torch.acos(cos_theta.clamp(-1 + 1e-7, 1 - 1e-7)) # 添加钳位防止acos溢出
cos_theta_m = cos_theta - m
idx = cos_theta > math.pi - m # 为防止溢出,对角度接近π的情况进行特殊处理
cos_theta_m[idx] = torch.cos(theta[idx] + m)
# 对于减法部分,需要广播以保持维度一致
adjustment = torch.where(idx.unsqueeze(1), torch.exp(s * torch.cos(theta.unsqueeze(1))), torch.tensor(0., device=cos_theta.device))
numerator = torch.exp(s * cos_theta_m)
denominator = numerator.sum(dim=1, keepdim=True) - adjustment.sum(dim=1, keepdim=True) + numerator
# 确保denominator非零
denominator = torch.clamp(denominator, min=1e-10)
arcface_loss = torch.log(numerator / denominator)
这里的修改主要是为了确保减法操作前后的张量在形状上是一致的。
通过使用unsqueeze(1)来增加维度,使得adjustment张量能够与numerator进行广播操作,
同时使用where函数来确保只有在idx标记为True的位置上才进行exp(s * cos(theta))的计算,
其他位置保持为0,这样可以正确地进行减法而不改变张量的形状。
此外,为了避免对数函数中分母为零的问题,添加了clamp函数来限制最小值。
return arcface_loss
'''
class GoodNet(nn.Module):
    """DenseNet169 backbone + 1000->512 projection, with an ArcFace-style head.

    ``forward`` returns (feature, arc_softmax_logits); ``get_feature`` returns
    only the 512-d embedding used for cosine-similarity matching.
    """
    def __init__(self):
        super(GoodNet, self).__init__()
        self.nll_loss = nn.NLLLoss()
        self.loss_fn = nn.CrossEntropyLoss()
        # Backbones that were compared (DenseNet121/169/201); DenseNet169 won.
        # self.sub_net = nn.Sequential(
        #     models.densenet121(weights = models.DenseNet169_Weights.IMAGENET1K_V1)
        #     models.densenet169(weights = models.DenseNet169_Weights.IMAGENET1K_V1)
        #     models.densenet201(weights = models.DenseNet169_Weights.IMAGENET1K_V1)
        # )
        # weights=None: the trained weights are loaded from ***_new.pt instead
        self.sub_net = models.densenet169(weights=None)  # ***_new.pt
        # Project the backbone's 1000-d classifier output down to a 512-d feature.
        self.feature_net = nn.Sequential(
            nn.BatchNorm1d(1000),
            nn.LeakyReLU(0.1),
            nn.Linear(1000, 512, bias=False),
        )
        self.arc_softmax = Arcsoftmax(512, one_hot_size)

    def forward(self, x):
        """Return (512-d feature, arc-softmax log-probabilities)."""
        feature = self.feature_net(self.sub_net(x))
        return feature, self.arc_softmax(feature)

    def get_feature(self, x):
        """Return only the 512-d embedding for similarity matching."""
        return self.feature_net(self.sub_net(x))

    def getSoftmaxLoss(self, outputs, labels):
        """NLL loss over the (already log-space) arc-softmax outputs."""
        return self.nll_loss(outputs,labels)
if __name__ == '__main__':
    # Smoke test: build the net and load the separately-saved sub-networks.
    net = GoodNet()
    net.sub_net.load_state_dict(torch.load("weight/sub_net.pt"))
    net.feature_net.load_state_dict(torch.load("weight/feature_net.pt"))
    # net.load_state_dict(torch.load("weight/net.pt"))
    print(net)
dataset.py
作用:特征网络训练测试训练集加载器
import glob
import os.path
import cv2
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from good_net import one_hot_size
from good_cls_data import one_hot_dict
# Training-time pipeline: random flips for augmentation, then resize to the
# network's 320x320 input size.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(p=0.5),  # horizontal flip with probability 0.5
    transforms.RandomVerticalFlip(p=0.5),  # vertical flip with probability 0.5
    transforms.Resize((320, 320), antialias=True)
])
# Evaluation pipeline: resize only, no augmentation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((320, 320), antialias=True)
])
class TrainDataset(Dataset):
    """Training-split loader for the feature network.

    Expects root/<big_class>/<label>/<image>; the label is the image's parent
    directory name. Each item is (class_index_tensor, image_tensor).

    Fixes: label extraction used ``path.rsplit('\\', ...)`` which only works
    on Windows — replaced with ``os.path`` so forward-slash paths work too;
    the one-hot-then-argmax dance is collapsed to a direct index lookup.
    """
    def __init__(self, root=r"./data/CLSDataset_train", transform=train_transform):
        super().__init__()
        img_paths = glob.glob(os.path.join(root, "*", "*", "*"))
        self.dataset = []
        for path in img_paths:
            # Parent-directory name is the class label; os.path handles both
            # '/' and '\\' separators.
            label = os.path.basename(os.path.dirname(path))
            self.dataset.append((label, path))
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        label, img_path = self.dataset[idx]
        img = Image.open(img_path)
        img_tensor = self.transform(img)
        # Equivalent to building a one-hot vector and taking argmax, without
        # the intermediate tensor; still returns a long tensor.
        one_hot_idx = torch.tensor(one_hot_dict[label])
        return one_hot_idx, img_tensor
class TestDataset(Dataset):
    """Test-split loader for the feature network.

    Expects root/<label>/<image>; the label is the image's parent directory
    name. Each item is (class_index_tensor, image_tensor).

    Fixes (mirroring TrainDataset): portable label extraction via ``os.path``
    instead of Windows-only ``'\\'`` splitting, and a direct index lookup
    instead of one-hot-then-argmax.
    """
    def __init__(self, root=r"./data/CLSDataset_test", transform=test_transform):
        super().__init__()
        img_paths = glob.glob(os.path.join(root, "*", "*"))
        self.dataset = []
        for path in img_paths:
            # Parent-directory name is the class label; works on any OS.
            label = os.path.basename(os.path.dirname(path))
            self.dataset.append((label, path))
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        label, img_path = self.dataset[idx]
        img = Image.open(img_path)
        img_tensor = self.transform(img)
        # Direct class index; same long tensor the original argmax produced.
        one_hot_idx = torch.tensor(one_hot_dict[label])
        return one_hot_idx, img_tensor
class TestDataset2(Dataset):
    """Per-class-batch test set: item ``i`` yields the i-th image from every
    class at once, so one __getitem__ already forms a full evaluation batch.
    """
    def __init__(self, root_dir=r"./data/CLSDataset_test2", transform=test_transform):
        self.root_dir = root_dir
        self.classes = sorted(os.listdir(root_dir))  # all class directories
        # class name -> list of relative image paths inside that class dir
        self.class_images = {cls: [os.path.join(cls, img) for img in os.listdir(os.path.join(root_dir, cls))] for cls in self.classes}
        self.transform = transform
        self.class_count = len(self.classes)
        self.image_per_class = 10  # images used per class

    def __len__(self):
        # one batch per image index: 10 batches, each with one image per class
        return self.image_per_class

    def __getitem__(self, index):
        images_of_batch = []
        labels_of_batch = []
        for class_index in range(self.class_count):
            class_name = self.classes[class_index]
            # modulo keeps the index in range even if called past image_per_class
            img_path = self.class_images[class_name][index % self.image_per_class]
            img = Image.open(os.path.join(self.root_dir, img_path))
            if self.transform is not None:
                img = self.transform(img)
            images_of_batch.append(img)
            labels_of_batch.append(class_index)  # the class id is simply the index
        # pack the whole batch together
        images_of_batch = torch.stack(images_of_batch)  # list -> (class_count, C, H, W) tensor
        labels_of_batch = torch.tensor(labels_of_batch)  # class ids -> tensor
        return labels_of_batch, images_of_batch  # whole batch of labels and images
if __name__ == '__main__':
    '''
    m = TestDataset()
    one_hot,img_tensor = m[3]
    img_path = r'D:\zhrdpy_project\AutShop\data\CLSDataset_train\bag\lao_mu_ji_tang_mian_dai_zhuang\3382.png'
    img = Image.open(img_path)
    img.show('1')
    imgcv2 = cv2.imread(img_path)
    cv2.imshow('2',imgcv2)
    cv2.waitKey(0)
    '''
    # build the per-class-batch dataset
    dataset = TestDataset2()
    # batch loading: batch_size=1 because every item already contains one
    # image from each class
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
    i = 0
    # iterate the loader as a smoke test
    for labels ,images in dataloader:
        # test logic goes here
        print(labels)
        i+=1
        pass
    print(i)
good_cls_data.py
作用:通用分类文本信息
# Number of registered product classes (size of the one-hot vector).
one_hot_size = 25
# Registered product id (pinyin) -> class index used for training labels.
one_hot_dict = {
    'bai_shi_ke_le_guan_zhuang': 0,
    'ke_kou_ke_le_guan_zhuang': 1,
    'yong_chuang_tian_ya_guan_zhuang': 2,
    'jian_li_bao_guan_zhuang': 3,
    'le_shi_shu_pian_xiao_tong_huanggua_wei': 4,
    'xue_bi_guan_zhuang': 5,
    'qiao_ke_li_bing_gan_he_zhuang': 6,
    'bing_lu_ping_zhuang': 7,
    'le_shi_shu_pian_da_tong_huanggua_wei': 8,
    'le_shi_shu_pian_da_tong_ning_meng_wei': 9,
    'kang_shi_fu_hong_shao_niu_rou_mian_tong_zhuang': 10,
    'bing_ling_qi_bu_ding_he_zhuang': 11,
    'lao_tan_suan_cai_niu_rou_mian_tong_zhuang': 12,
    'a_sa_mu_nai_cha_ping_zhuang': 13,
    'mei_zhi_yuan_ping_zhuang': 14,
    'xing_qiu_bei_yi_tong': 15,
    'nong_fu_shan_quan_ping_zhuang': 16,
    'le_shi_shu_pian_dai_zhuang_niu_pai_wei': 17,
    'le_shi_shu_pian_dai_zhuang_huanggua_wei': 18,
    'lao_mu_ji_tang_mian_dai_zhuang': 19,
    'ou_de_cui_pian_dai_zhuang_jie_mo_wei': 20,
    'mai_xiang_ji_wei_kuai_dai_zhuang': 21,
    'wei_hua_bing_gan_he_zhuang': 22,
    'O_pao_nai_he_zhuang': 23,
    'chun_niu_nai_he_zhuang': 24,
}
# Unregistered reference samples, kept only as "unknown" probes for testing:
# le_shi_shu_pian_xiao_tong_shaokao_wei (乐事薯片小桶烧烤味) — not registered
# lao_tan_pao_jiao_niu_rou_mian_tong_zhuang (老坛泡椒牛肉面桶装) — not registered
# xiang_jiao_niu_nai_he_zhuang (香蕉牛奶盒装) — not registered
# class names (Chinese display names, comma-separated, in label order):
# 百事可乐罐装,可口可乐罐装,勇闯天涯罐装,健力宝罐装,乐事薯片小罐黄瓜味,雪碧罐装,乐事薯片小罐烧烤味,冰露瓶装,乐事薯片大罐黄瓜味,乐事薯片大罐柠檬味,康师傅红烧牛肉面,老坛泡椒牛肉面,老坛酸菜牛肉面,阿萨姆奶茶瓶装,美汁源瓶装,星球杯一桶,农夫山泉瓶装,乐事薯片袋装牛排味,乐事薯片袋装黄瓜味,老母鸡汤面袋装,藕的脆片袋装芥末味,麦香鸡味块袋装,威化饼干盒装,O泡奶盒装,纯牛奶盒装,冰淇淋布丁盒装,巧克力饼干盒装,香蕉牛奶盒装,
# Chinese display name -> product id (includes the three unregistered items).
name_dict = {
    '百事可乐罐装':'bai_shi_ke_le_guan_zhuang',
    '可口可乐罐装':'ke_kou_ke_le_guan_zhuang',
    '勇闯天涯罐装':'yong_chuang_tian_ya_guan_zhuang',
    '健力宝罐装':'jian_li_bao_guan_zhuang',
    '乐事薯片小桶黄瓜味':'le_shi_shu_pian_xiao_tong_huanggua_wei',
    '雪碧罐装':'xue_bi_guan_zhuang',
    '乐事薯片小桶烧烤味':'le_shi_shu_pian_xiao_tong_shaokao_wei',
    '冰露瓶装':'bing_lu_ping_zhuang',
    '乐事薯片大桶黄瓜味':'le_shi_shu_pian_da_tong_huanggua_wei',
    '乐事薯片大桶柠檬味':'le_shi_shu_pian_da_tong_ning_meng_wei',
    '康师傅红烧牛肉面':'kang_shi_fu_hong_shao_niu_rou_mian_tong_zhuang',
    '老坛泡椒牛肉面桶装':'lao_tan_pao_jiao_niu_rou_mian_tong_zhuang',
    '老坛酸菜牛肉面桶装':'lao_tan_suan_cai_niu_rou_mian_tong_zhuang',
    '阿萨姆奶茶瓶装':'a_sa_mu_nai_cha_ping_zhuang',
    '美汁源瓶装':'mei_zhi_yuan_ping_zhuang',
    '星球杯一桶':'xing_qiu_bei_yi_tong',
    '农夫山泉瓶装':'nong_fu_shan_quan_ping_zhuang',
    '乐事薯片袋装牛排味':'le_shi_shu_pian_dai_zhuang_niu_pai_wei',
    '乐事薯片袋装黄瓜味':'le_shi_shu_pian_dai_zhuang_huanggua_wei',
    '老母鸡汤面袋装':'lao_mu_ji_tang_mian_dai_zhuang',
    '藕的脆片袋装芥末味':'ou_de_cui_pian_dai_zhuang_jie_mo_wei',
    '麦香鸡味块袋装':'mai_xiang_ji_wei_kuai_dai_zhuang',
    '威化饼干盒装':'wei_hua_bing_gan_he_zhuang',
    'O泡奶盒装':'O_pao_nai_he_zhuang',
    '纯牛奶盒装':'chun_niu_nai_he_zhuang',
    '冰淇淋布丁盒装':'bing_ling_qi_bu_ding_he_zhuang',
    '巧克力饼干盒装':'qiao_ke_li_bing_gan_he_zhuang',
    '香蕉牛奶盒装':'xiang_jiao_niu_nai_he_zhuang'
}
# Product id -> container category (used to narrow the feature search to the
# same YOLO class: can / bottle / bucket / bag / box).
cls_dict = {
    'bai_shi_ke_le_guan_zhuang':'can',
    'ke_kou_ke_le_guan_zhuang':'can',
    'yong_chuang_tian_ya_guan_zhuang':'can',
    'jian_li_bao_guan_zhuang':'can',
    'le_shi_shu_pian_xiao_tong_huanggua_wei':'bucket',
    'xue_bi_guan_zhuang':'can',
    'le_shi_shu_pian_xiao_tong_shaokao_wei':'bucket',
    'bing_lu_ping_zhuang':'bottle',
    'le_shi_shu_pian_da_tong_huanggua_wei':'bucket',
    'le_shi_shu_pian_da_tong_ning_meng_wei':'bucket',
    'kang_shi_fu_hong_shao_niu_rou_mian_tong_zhuang':'bucket',
    'lao_tan_pao_jiao_niu_rou_mian_tong_zhuang':'bucket',
    'lao_tan_suan_cai_niu_rou_mian_tong_zhuang':'bucket',
    'a_sa_mu_nai_cha_ping_zhuang':'bottle',
    'mei_zhi_yuan_ping_zhuang':'bottle',
    'xing_qiu_bei_yi_tong':'bucket',
    'nong_fu_shan_quan_ping_zhuang':'bottle',
    'le_shi_shu_pian_dai_zhuang_niu_pai_wei':'bag',
    'le_shi_shu_pian_dai_zhuang_huanggua_wei':'bag',
    'lao_mu_ji_tang_mian_dai_zhuang':'bag',
    'ou_de_cui_pian_dai_zhuang_jie_mo_wei':'bag',
    'mai_xiang_ji_wei_kuai_dai_zhuang':'bag',
    'wei_hua_bing_gan_he_zhuang':'box',
    'O_pao_nai_he_zhuang':'box',
    'chun_niu_nai_he_zhuang':'box',
    'bing_ling_qi_bu_ding_he_zhuang':'box',
    'qiao_ke_li_bing_gan_he_zhuang':'box',
    'xiang_jiao_niu_nai_he_zhuang':'box'
}
from torchvision import transforms
# Built once at import time: the original rebuilt the Compose pipeline on
# every call, which is wasted work in the per-frame inference loop.
_SIZE_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((320, 320), antialias=True)
])

def img_transforms(image):
    """Convert an image (PIL image or HxWxC ndarray) to a 320x320 float tensor."""
    return _SIZE_TRANSFORM(image)
save_feature.py
作用:保存和读取特征文件
import torch
import glob
import cv2
import os
from good_cls_data import cls_dict
import good_net
from good_cls_data import img_transforms
import json
# Fallback (empty) registry returned when no feature file exists yet.
mydict={}
# File the feature dictionary is persisted to / loaded from.
filepath = "dictionary.pt"
def add_dict_feature(feature, name, my_dict):
    """Store `feature` under `name` in the given registry dict (in place)."""
    my_dict[name] = feature
    print(f"{name}已写入字典")
def save_dict_feature(my_dict):
    """Persist the feature registry to `filepath` via torch.save."""
    torch.save(my_dict, filepath)
    print("已保存字典到文件:", filepath)
def load_dict_feature():
    """Load the persisted feature registry, regrouped by container category.

    Returns {category: {product_sample_name: feature}}; when no feature file
    exists yet, returns the module-level empty ``mydict``.

    Fix: the original used a bare ``except:`` that silently swallowed every
    error (corrupt files, unknown category keys, even KeyboardInterrupt) as
    "file not found"; only the missing-file case is handled now.
    """
    try:
        loaded_dict = torch.load(filepath)
    except FileNotFoundError:
        print("没有找到特征文件")
        return mydict
    cls_goodF_dict = {}
    for name, good_feature in loaded_dict.items():
        # strip the trailing sample index: 'foo_bar_123' -> 'foo_bar'
        good_name_front = name.rsplit('_', maxsplit=1)[-2]
        # nest under the product's container category (can/bottle/...)
        cls_goodF_dict.setdefault(cls_dict[good_name_front], {})[name] = good_feature
    return cls_goodF_dict
def log_all():  # deprecated since the registry was re-keyed by category
    """Register features for every image under log_img (deprecated).

    Saves, clears and reloads the dict once per folder: keeping every feature
    tensor (with its autograd graph) referenced at once exhausted RAM, so the
    per-folder save/clear cycle lets the GC reclaim the graphs.
    """
    net = good_net.GoodNet()
    net.load_state_dict(torch.load("weight/best_new957.pt"))
    net.eval()
    loggood_dict = load_dict_feature()
    # not enough memory to hold everything: store folder by folder
    base_path = 'D:/zhrdpy_project/AutShop/data/log_img'
    dir_paths = glob.glob(os.path.join(base_path, "*"))
    for dir_path in dir_paths:
        img_paths = glob.glob(os.path.join(dir_path, "*"))
        for path in img_paths:
            img_name = path.rsplit('\\', maxsplit=2)[-1]  # NOTE: Windows-only '\\' split
            name = img_name.split('.', maxsplit=2)[0]
            img = cv2.imread(path)
            image_tensor = img_transforms(img)
            image_tensor = torch.unsqueeze(image_tensor,dim=0)
            feature = net.get_feature(image_tensor)
            add_dict_feature(feature, name, loggood_dict)
        save_dict_feature(loggood_dict)
        loggood_dict.clear()  # drop references so the GC can reclaim the tensors
        loggood_dict = load_dict_feature()
'''
当我通过模型得到特征向量,并将其存储到字典中时,原先的特征向量(torch.Tensor对象)所占用的内存并不会立即被释放。
这是因为Python的垃圾回收机制(Garbage Collector,GC)并不保证在对象不再被引用时立即回收其占用的内存,尤其是在处理大型数据结构时。
GC的工作机制是周期性的,它会在内存压力达到一定程度或经过一定的时间间隔后运行,来清理不再使用的对象。
在我的场景中,每一次模型计算产生的torch.Tensor对象都会占用一定的内存空间,即使该对象随后被放入字典并可能被新的条目覆盖,
只要该对象还在某个地方被引用(就比如在我的字典中),它的内存就不会被立即释放。这意味着,如果我不清空字典,那些曾经存储在字典中的torch.Tensor对象,
它们的生命周期还没有结束(即还有其他引用指向它们),那么它们占用的内存就会持续存在,直到GC运行并确定它们确实不再被引用,才会回收这部分内存。
GC不会回收那些torch.Tensor对象所以我的内存就会爆,当我手动清空字典(使用dict.clear()或重新初始化字典)时,字典内部对所有torch.Tensor对象的引用都被移除,
此时如果这些对象没有其他外部引用,它们将变成孤立的对象,不再被任何变量引用,这就使得它们满足了被GC回收的条件。一旦GC运行,它就能检测到这些孤立的torch.Tensor对象,
并释放它们占用的内存。我需要的仅仅只是计算结果,所以重新读取的特征向量只有计算结果而其之前计算产生的torch.Tensor对象已经被回收了,我的内存就不会爆。
'''
def log_all_without_grad():
    """Register features for every image under log_img.

    Runs the feature net under ``torch.no_grad()`` so no autograd graph is
    retained, keeping memory flat across thousands of images (this replaces
    the per-folder save/clear dance of the deprecated ``log_all``).

    Fixes: filename parsing used a Windows-only ``'\\'`` split — replaced
    with ``os.path``; unreadable images (cv2.imread -> None) are now skipped
    with a warning instead of crashing the whole registration run.
    """
    net = good_net.GoodNet()
    net.load_state_dict(torch.load("weight/best_new957.pt"))
    net.eval()
    loggood_dict = {}
    base_path = 'D:/zhrdpy_project/AutShop/data/log_img'
    dir_paths = glob.glob(os.path.join(base_path, "*"))
    for dir_path in dir_paths:
        img_paths = glob.glob(os.path.join(dir_path, "*"))
        for path in img_paths:
            # portable: basename + splitext instead of splitting on '\\'
            name = os.path.splitext(os.path.basename(path))[0]
            img = cv2.imread(path)
            if img is None:
                # cv2.imread returns None for unreadable files
                print(f"无法读取图片,已跳过: {path}")
                continue
            image_tensor = torch.unsqueeze(img_transforms(img), dim=0)
            with torch.no_grad():  # disable autograd to save memory
                feature = net.get_feature(image_tensor)
            add_dict_feature(feature, name, loggood_dict)
    save_dict_feature(loggood_dict)
def log_single():
    """Register the feature of one hand-picked image into the feature store."""
    model = good_net.GoodNet()
    model.load_state_dict(torch.load("weight/best_new957.pt"))
    model.eval()
    registry = load_dict_feature()
    path = 'D:/zhrdpy_project/AutShop/data/log_img/xiang_jiao_niu_nai_he_zhuang/xiang_jiao_niu_nai_he_zhuang_1040.jpg'
    # entry name = file name without its extension
    file_name = path.rsplit('/', maxsplit=2)[-1]
    entry_name = file_name.split('.', maxsplit=2)[0]
    image_tensor = torch.unsqueeze(img_transforms(cv2.imread(path)), dim=0)
    feature = model.get_feature(image_tensor)
    add_dict_feature(feature, entry_name, registry)
    print(entry_name + '成功注册')
    save_dict_feature(registry)
if __name__ == '__main__':
    # Rebuild the full feature registry, then reload and display it.
    log_all_without_grad()
    # log_single()
    # print('加载注册信息——————————————————————————————————————————')
    show = load_dict_feature()
    print(show)
注意:
- 当我通过模型得到特征向量,并将其存储到字典中时,原先的特征向量(torch.Tensor对象)所占用的内存并不会立即被释放。这是因为Python的垃圾回收机制(Garbage Collector,GC)并不保证在对象不再被引用时立即回收其占用的内存,尤其是在处理大型数据结构时。
GC的工作机制是周期性的,它会在内存压力达到一定程度或经过一定的时间间隔后运行,来清理不再使用的对象。在我的场景中,每一次模型计算产生的torch.Tensor对象都会占用一定的内存空间,即使该对象随后被放入字典并可能被新的条目覆盖,只要该对象还在某个地方被引用(就比如在我的字典中),它的内存就不会被立即释放。这意味着,如果我不清空字典,那些曾经存储在字典中的torch.Tensor对象,它们的生命周期还没有结束(即还有其他引用指向它们),那么它们占用的内存就会持续存在,直到GC运行并确定它们确实不再被引用,才会回收这部分内存。
GC不会回收那些torch.Tensor对象所以我的内存就会爆,当我手动清空字典(使用dict.clear()或重新初始化字典)时,字典内部对所有torch.Tensor对象的引用都被移除,此时如果这些对象没有其他外部引用,它们将变成孤立的对象,不再被任何变量引用,这就使得它们满足了被GC回收的条件。一旦GC运行,它就能检测到这些孤立的torch.Tensor对象,并释放它们占用的内存。我需要的仅仅只是计算结果,所以重新读取的特征向量只有计算结果而其之前计算产生的torch.Tensor对象已经被回收了,我的内存就不会爆。 - 另外可以直接禁用梯度计算(一开始没想到哈哈)
with torch.no_grad(): # 禁用梯度计算以节省内存和加速推理
analyse_good.py
作用:分析与跟踪商品
import torch
import torch.hub
import numpy as np
import cv2
import glob
import os
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
from good_cls_data import *
from save_feature import *
from good_net import GoodNet
from ultralytics import YOLO
import torch.nn.functional as F
import matplotlib
import time
matplotlib.use('TkAgg')
# Font used to draw Chinese labels on frames (SimSun, 30 px).
FONT = ImageFont.truetype('simsun.ttc', size=30)
# Label colors, cycled per tracked item.
COLOR = ['blue', 'green', 'yellow', 'orange', 'purple', 'brown', 'red']
# Weight files for the YOLO segmenter and the feature network.
YOLO_LOAD = r'D:\zhrdpy_project\AutShop\weight\best_YOLOs.pt'
MY_NET_LOAD = "weight/best_new957.pt"
# Confirmations required before a tracked item is considered verified.
DETECT_TIMES = 30
# Inference device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class GoodDetect:
    """Goods detector/recognizer: YOLO segmentation + feature-net matching.

    ``detect_vidio`` processes video frames with IoU-based tracking and a
    repeated-confirmation scheme; ``detect_img`` processes single images and
    can also register ("log") new product samples.
    """
    def __init__(self):
        # load the object-detection (YOLO segmentation) model
        self.yolo_model = self.load_yolomodel().to(device)
        # load the feature-extraction model
        self.screen_model = self.load_mymodel().to(device)
        # load the registered goods features
        self.loggood_dict = load_dict_feature()  # {cls:{'name':feature}}
        # cached per-object info, used to decide whether an object is new
        self.good_positions = []
        self.good_names = []
        self.goods_dic = {}  # {id:{'name_cn':'xx', 'good_cls':good_cls, 'position':good_position, 'detect_times':1}}
        self.id_num = 0
        # initial (all-black) crop region
        self.region_mask = np.zeros((320, 320, 3), dtype=np.uint8)

    def load_yolomodel(self):
        """Build the YOLO segmenter from YOLO_LOAD with a 0.85 confidence attr."""
        model = YOLO(YOLO_LOAD)
        model.conf = 0.85
        return model

    def load_mymodel(self):
        """Build the feature net, load MY_NET_LOAD weights, switch to eval mode."""
        net = GoodNet()
        net.load_state_dict(torch.load(MY_NET_LOAD))
        net.eval()
        return net

    def compera_good(self, target_good_feature, good_cls, threshold=0.96):
        """Match a feature against registered goods of the same category.

        Counts, per product, how many registered samples exceed the cosine
        threshold; returns the most frequent product (ties broken by highest
        similarity) or 'unknown' when nothing reaches the threshold.
        (Name 'compera' kept as-is: callers use it.)
        """
        matching_goods = {}  # product -> number of samples above threshold
        good_similarity = {}  # product -> best similarity, used to break ties
        for cls, dic in self.loggood_dict.items():
            if good_cls == cls :
                # k = 0
                for name, good_feature in dic.items():
                    similarity = F.cosine_similarity(target_good_feature, good_feature.to(device)).item()
                    if similarity >= threshold:
                        print('相似度达到阈值:' + name + ':' + str(similarity))
                        good_name_front = name.rsplit('_', maxsplit=1)[-2]
                        # track each product's best similarity
                        if good_name_front in good_similarity:
                            if similarity > good_similarity[good_name_front]:
                                good_similarity[good_name_front] = similarity
                        else:
                            good_similarity[good_name_front] = similarity
                        # count matches per product
                        if good_name_front in matching_goods:
                            matching_goods[good_name_front] += 1
                        else:
                            matching_goods[good_name_front] = 1
                        # k += 1
                # print(f'检测次数{k}')
        # no product reached the threshold -> unknown
        if not matching_goods:
            return 'unknown'
        # products with the highest match count
        most_common_goods = [k for k, v in matching_goods.items() if v == max(matching_goods.values())]
        # print('出现次数最多:' + str(most_common_goods))
        # among equally frequent products, pick the most similar one
        if len(most_common_goods) > 1:
            most_similar_good = max(most_common_goods, key=lambda x: good_similarity[x])
            return most_similar_good
        else:
            return most_common_goods[0]

    """弃用
    def compera_good(self,target_good_feature,config=0.90):
    max_similarity = -1 # 初始化最大相似度为负数,保证之后会被更新
    good_name = None
    for name, good_feature in self.loggood_dict.items():
    similarity = F.cosine_similarity(target_good_feature, good_feature)
    if similarity.item() > max_similarity:
    max_similarity = similarity.item()
    good_name = name
    print(max_similarity)
    if max_similarity < config:
    good_name = 'unkown'
    return good_name
    return good_name
    """

    def compare_position(self, box1, box2, iou_threshold=0.3):
        """True when box1/box2 (x1,y1,x2,y2) overlap with IoU >= threshold,
        i.e. they are likely the same object across frames."""
        x1_1, y1_1, x2_1, y2_1 = box1
        x1_2, y1_2, x2_2, y2_2 = box2
        # areas of both rectangles
        area_1 = (x2_1 - x1_1) * (y2_1 - y1_1)
        area_2 = (x2_2 - x1_2) * (y2_2 - y1_2)
        # intersection rectangle
        intersect_x1 = max(x1_1, x1_2)
        intersect_y1 = max(y1_1, y1_2)
        intersect_x2 = min(x2_1, x2_2)
        intersect_y2 = min(y2_1, y2_2)
        # clamp to zero so disjoint boxes contribute no area
        intersection_area = max(0, intersect_x2 - intersect_x1) * max(0, intersect_y2 - intersect_y1)
        # IoU
        iou = intersection_area / (area_1 + area_2 - intersection_area)
        # True means "same object" (not a new product)
        return iou >= iou_threshold

    def get_region_mask(self,img,rect,box2,ls):
        """Cut the detected object out of `img`, deskew it by warping its
        minimum-area rectangle upright, pad it to a square on black, and store
        the result in self.region_mask."""
        # all-black mask the size of the source image
        b_mask = np.zeros(img.shape[:2], np.uint8)
        # fill the segmentation contour
        cv2.drawContours(b_mask, [ls], -1, (255, 255, 255), cv2.FILLED)
        (height, width) = rect[1]
        width = int(width)
        height = int(height)
        src_points = box2.squeeze().astype(np.float32)
        dst_points = np.float32([[0, 0], [width, 0], [width, height], [0, height]])
        # perspective transform mapping the rotated rect to an upright one
        M = cv2.getPerspectiveTransform(src_points, dst_points)
        # black 3-channel canvas, same size as img
        isolated_3_channel = np.zeros_like(img)
        # copy only the masked pixels from the original image
        isolated_3_channel[b_mask > 0] = img[b_mask > 0]
        # apply the perspective warp
        warped_image = cv2.warpPerspective(isolated_3_channel, M, (width, height))
        # deskewed minimum bounding rectangle
        region = warped_image
        if width > height:
            region = cv2.rotate(region, cv2.ROTATE_90_COUNTERCLOCKWISE)
            width, height = height, width
        # square black canvas
        region_mask = np.zeros((height, height, 3), dtype=np.uint8)
        # top-left offset that centers the region on the canvas
        start_x = (height - width) // 2
        start_y = (height - height) // 2  # always 0; kept for symmetry
        # paste the region onto the canvas
        region_mask[start_y:start_y + height, start_x:start_x + width, :] = region
        # cv2.imshow('test',region_mask)
        self.region_mask = region_mask

    def region_to_tensor(self):
        """Square-pad self.region_mask and convert to a (1, 3, 320, 320) tensor."""
        return img_transforms(trans_square(Image.fromarray(self.region_mask))).unsqueeze(dim=0)

    def name_to_cn(self,good_name):
        """Map a pinyin product id to its Chinese display name ('未注册商品' for unknown)."""
        if good_name == 'unknown':
            name_cn = '未注册商品'
        else:
            # reverse lookup in name_dict (Chinese name -> pinyin id)
            name_cn = [k for k, v in name_dict.items() if v == good_name][0]
        return name_cn

    def detect_vidio(self, img):
        """Detect, track and label goods in one video frame.

        Returns (annotated frame, last cropped region, display-name list).
        ('vidio' spelling kept: external callers use this name.)
        """
        draw_img = np.copy(img)
        results = self.yolo_model(img)
        if results[0] is not None and results[0].masks is not None:
            number = len(results[0].masks.xy)
            good_names = []
            goods_cls = results[0].boxes.cls.tolist()
            for i in range(number):
                new_good = False  # assume not a new product by default
                good_cls = results[0].names[goods_cls[i]]
                contours = np.array(results[0].masks.xy[i], dtype=np.int32)
                ls = contours.reshape(-1, 1, 2)
                ln = contours.reshape(1, -1, 1, 2)
                # minimum-area rectangle
                # returns (center, (width, height), angle); angle is the rotation of
                # the rect's long (width) side relative to the horizontal
                rect = cv2.minAreaRect(ls)
                box = cv2.boxPoints(rect)
                box2 = np.int32(box).reshape(-1, 4, 2)
                # draw the min-area rectangle and the contour on the frame
                img_contour = cv2.polylines(draw_img, box2, True, (0, 0, 255), 3)
                img_contour = cv2.drawContours(draw_img, tuple(ln), -1, (0, 255, 0), 3)
                # axis-aligned bounding box of this detection
                if number == 1:
                    x1, y1, x2, y2 = results[0].boxes.xyxy.cpu().numpy().squeeze().astype(np.int32)
                else:
                    x1, y1, x2, y2 = results[0].boxes.xyxy.cpu().numpy().squeeze().astype(np.int32)[i]
                # record this frame's position (axis-aligned keeps comparisons consistent)
                good_position = (x1, y1, x2, y2)
                old_id = 0  # id of the existing tracked product to update (if any)
                # compare against positions from the previous frame
                if not self.goods_dic:
                    # nothing tracked yet -> definitely a new product
                    new_good = True
                else:
                    # walk the tracked products; a match breaks out, otherwise
                    # new_good ends up True after the last non-matching entry
                    for id in self.goods_dic:
                        # only compare products of the same category
                        if self.goods_dic[id]['good_cls'] == good_cls:
                            # same category: IoU decides whether it is the same object
                            if self.compare_position(good_position,self.goods_dic[id]['position']):
                                new_good = False
                                old_id = id  # remember which tracked product this is
                                '''为使熔断机制提前放弃有问题的帧,对已有商品的确认操作应置后
                                # 如果是存在过的商品则需要跟踪检测,已存在物体进行多次复查
                                if self.goods_dic[id]['detect_times'] <= DETECT_TIMES:
                                # 裁剪出目标商品区域
                                self.get_region_mask(img, rect, box2, ls)
                                # 开始计时
                                start_time = time.time()
                                # 传递mask获取特征
                                with torch.no_grad(): # 禁用梯度计算以节省内存
                                feature = self.screen_model.get_feature(self.region_to_tensor().to(device))
                                # 结束计时
                                end_time = time.time()
                                # 计算推理时间ms
                                inference_time = (end_time - start_time)*1000
                                # 比较余弦相似度获取名称
                                good_name = self.compera_good(feature,good_cls)
                                name_cn = self.name_to_cn(good_name)
                                # 判断之前检测结果是否有误
                                if name_cn == self.goods_dic[id]['name_cn']:
                                self.goods_dic[id]['detect_times'] += 1
                                print(f"本次模型处理{self.goods_dic[id]['name_cn']}用时: {inference_time:.1f} ms ---第{self.goods_dic[id]['detect_times']}次检测无误")
                                else:
                                print(f"本次模型处理{self.goods_dic[id]['name_cn']}用时: {inference_time:.1f} ms ---发现检测错误,已更正为{name_cn}!")
                                self.goods_dic[id]['name_cn'] = name_cn
                                self.goods_dic[id]['detect_times'] = 1
                                # 更新本帧位置信息
                                self.goods_dic[id]['position'] = good_position
                                # 已达侦测次数视为达标
                                else:
                                print(f'{self.goods_dic[id]["name_cn"]}达到检测次数{DETECT_TIMES}标准,已确认,不再检测')
                                '''
                                # confirmed not new -> stop scanning
                                break
                            else:
                                new_good = True
                        else:
                            # no same-category tracked product seen so far
                            new_good = True
                # handle a new product
                if new_good:
                    # register it in goods_dic with a fresh id
                    id = self.id_num
                    self.id_num += 1
                    self.goods_dic[id] = {'name_cn':'xx', 'good_cls':good_cls, 'position':good_position, 'detect_times':1}
                    # crop the target product region
                    self.get_region_mask(img, rect, box2, ls)
                    # start timing
                    start_time = time.time()
                    # extract the feature from the mask
                    with torch.no_grad():  # disable autograd to save memory and speed up inference
                        feature = self.screen_model.get_feature(self.region_to_tensor().to(device))
                    # stop timing
                    end_time = time.time()
                    # inference time in ms
                    inference_time = (end_time - start_time)*1000
                    # cosine-similarity lookup for the product name
                    good_name = self.compera_good(feature,good_cls)
                    name_cn = self.name_to_cn(good_name)
                    self.goods_dic[id]['name_cn'] = name_cn
                    print(f"本次模型处理{self.goods_dic[id]['name_cn']}用时: {inference_time:.1f} ms ---侦测到新商品!!!")
                    # circuit breaker: more tracked items than detections -> bad frame
                    if len(self.goods_dic) > number:
                        self.goods_dic.clear()
                        return draw_img, self.region_mask, ['请勿遮挡商品!']
                # handle an already-tracked product
                if new_good == False:
                    id = old_id
                    # re-check tracked products until they reach DETECT_TIMES confirmations
                    if self.goods_dic[id]['detect_times'] <= DETECT_TIMES:
                        # crop the target product region
                        self.get_region_mask(img, rect, box2, ls)
                        # start timing
                        start_time = time.time()
                        # extract the feature from the mask
                        with torch.no_grad():  # disable autograd to save memory and speed up inference
                            feature = self.screen_model.get_feature(self.region_to_tensor().to(device))
                        # stop timing
                        end_time = time.time()
                        # inference time in ms
                        inference_time = (end_time - start_time) * 1000
                        # cosine-similarity lookup for the product name
                        good_name = self.compera_good(feature, good_cls)
                        name_cn = self.name_to_cn(good_name)
                        # check whether the previous identification was wrong
                        if name_cn == self.goods_dic[id]['name_cn']:
                            self.goods_dic[id]['detect_times'] += 1
                            print(f"本次模型处理{self.goods_dic[id]['name_cn']}用时: {inference_time:.1f} ms ---第{self.goods_dic[id]['detect_times']}次检测无误")
                        else:
                            print(f"本次模型处理{self.goods_dic[id]['name_cn']}用时: {inference_time:.1f} ms ---发现检测错误,已更正为{name_cn}!")
                            self.goods_dic[id]['name_cn'] = name_cn
                            self.goods_dic[id]['detect_times'] = 1
                        # update this frame's position
                        self.goods_dic[id]['position'] = good_position
                    # enough confirmations -> the product is considered verified
                    else:
                        print(f'{self.goods_dic[id]["name_cn"]}达到检测次数{DETECT_TIMES}标准,已确认,不再检测')
            # report this frame's tracked products
            if self.goods_dic:
                for idx,id in enumerate(self.goods_dic):
                    good_names.append(self.goods_dic[id]['name_cn'])
                    draw_img = cv2ImgAddText(draw_img, self.goods_dic[id]['name_cn'], self.goods_dic[id]['position'][0], self.goods_dic[id]['position'][1] - 30, textColor=COLOR[idx % len(COLOR)])
                return draw_img,self.region_mask,good_names
        return draw_img,self.region_mask,['没有商品']

    def detect_img(self, img, log=True, log_name='0'):
        """Detect goods in a single image.

        log=True registers each detection's feature under ``log_name``;
        log=False recognizes detections against the registered features.
        Returns (annotated image, last cropped square region, name list).
        """
        draw_img = np.copy(img)
        region = img[0:320, 0:320]
        region_mask = img[0:320, 0:320]
        results = self.yolo_model(img)
        if results[0] is not None and results[0].masks is not None:
            number = len(results[0].masks.xy)
            good_names = []
            goods_cls = results[0].boxes.cls.tolist()
            for i in range(number):
                good_cls = results[0].names[goods_cls[i]]
                contours = np.array(results[0].masks.xy[i], dtype=np.int32)
                ls = contours.reshape(-1, 1, 2)
                ln = contours.reshape(1, -1, 1, 2)
                # all-black mask the size of the source image
                b_mask = np.zeros(img.shape[:2], np.uint8)
                # fill the segmentation contour
                cv2.drawContours(b_mask, [ls], -1, (255, 255, 255), cv2.FILLED)
                # cut the masked area out of the original image (deprecated, v2)
                # extracted_img = cv2.bitwise_and(img, img, mask=b_mask)
                # mask composite (deprecated)
                # isolated = np.dstack([img, b_mask])
                # minimum-area rectangle
                # returns (center, (width, height), angle); angle is the rotation of
                # the rect's long (width) side relative to the horizontal
                rect = cv2.minAreaRect(ls)
                box = cv2.boxPoints(rect)
                box2 = np.int32(box).reshape(-1, 4, 2)
                (height, width) = rect[1]
                width = int(width)
                height = int(height)
                src_points = box2.squeeze().astype(np.float32)
                dst_points = np.float32([[0, 0], [width, 0], [width, height], [0, height]])
                # perspective transform mapping the rotated rect upright
                M = cv2.getPerspectiveTransform(src_points, dst_points)
                # Create contour mask (redundant: contour was already filled above)
                _ = cv2.drawContours(b_mask, [ls], -1, (255, 255, 255), cv2.FILLED)
                # black 3-channel canvas, same size as img
                isolated_3_channel = np.zeros_like(img)
                # copy only the masked pixels from the original image
                isolated_3_channel[b_mask > 0] = img[b_mask > 0]
                # apply the perspective warp
                warped_image = cv2.warpPerspective(isolated_3_channel, M, (width, height))
                # deskewed minimum bounding rectangle
                region = warped_image
                if width > height:
                    region = cv2.rotate(region, cv2.ROTATE_90_COUNTERCLOCKWISE)
                    width, height = height, width
                # square black canvas
                region_mask = np.zeros((height, height, 3), dtype=np.uint8)
                # top-left offset that centers the region on the canvas
                start_x = (height - width) // 2
                start_y = (height - height) // 2
                # paste the region onto the canvas
                region_mask[start_y:start_y + height, start_x:start_x + width, :] = region
                # cv2.imshow('test',region_mask)
                warped_image = Image.fromarray(region_mask)
                warped_image = trans_square(warped_image)
                warped_image = img_transforms(warped_image)
                warped_image = warped_image.unsqueeze(dim=0)
                """
                # 弃用2
                angle = rect[-1]
                # 计算旋转角度,使短边为宽
                if rect[1][0] > rect[1][1]:
                angle = -(90 - angle)
                # 获取旋转矩阵 正角度(正数)表示逆时针旋转。
                # 负角度(负数)表示顺时针旋转。
                center = (rect[0][0], rect[0][1])
                M = cv2.getRotationMatrix2D(center, angle, 1.0)
                # 旋转图像
                (h, w) = extracted_img.shape[:2]
                rotated = cv2.warpAffine(extracted_img, M, (w, h))
                # 提取旋转后的矩形区域
                width, height = rect[1][0], rect[1][1]
                if width > height:
                width, height = height, width
                cropped_rotated = cv2.getRectSubPix(rotated, (int(width), int(height)), center)
                # 将numpy数组转换为PIL图像
                pil_image = Image.fromarray((cropped_rotated).astype('uint8'))
                # 转换tensor
                input_tensor = img_transforms(pil_image).unsqueeze(dim=0)
                """
                if number == 1:
                    x1, y1, x2, y2 = results[0].boxes.xyxy.cpu().numpy().squeeze().astype(np.int32)
                else:
                    x1, y1, x2, y2 = results[0].boxes.xyxy.cpu().numpy().squeeze().astype(np.int32)[i]
                # crop the mask area (deprecated)
                # iso_crop = isolated[y1:y2, x1:x2]
                # iso_crop_3channel = iso_crop[:, :, :3]
                # start timing
                start_time = time.time()
                # extract the feature from the mask
                # NOTE(review): this featurizes self.region_mask (set by
                # get_region_mask / a previous detect_vidio call), NOT the
                # locally computed warped_image tensor above, which is never
                # used — looks like a bug; confirm intended input.
                with torch.no_grad():  # disable autograd to save memory and speed up inference
                    feature = self.screen_model.get_feature(self.region_to_tensor().to(device))
                # stop timing
                end_time = time.time()
                # inference time in ms
                inference_time = (end_time - start_time)*1000
                print(f"本次模型处理用时: {inference_time:.1f} ms")
                if log == True:
                    # registration mode: store the feature under log_name
                    good_name = log_name
                    cn_name = good_name
                    add_dict_feature(feature,good_name,self.loggood_dict)
                    good_names.append(good_name + '成功注册')
                    # save_dict_feature(self.loggood_dict)
                else:
                    # recognition mode: match against registered features
                    good_name = self.compera_good(feature,good_cls)
                    if good_name == 'unknown':
                        cn_name = '未注册商品'
                    else:
                        cn_name = [k for k, v in name_dict.items() if v == good_name][0]
                    good_names.append(cn_name)
                # draw the label and the minimum-area rectangle
                draw_img = cv2ImgAddText(draw_img,cn_name,x1,y1-30,textColor=COLOR[i % len(COLOR)])
                img_contour = cv2.polylines(draw_img, box2, True, (0, 0, 255), 3)
                img_contour = cv2.drawContours(draw_img, tuple(ln), -1, (0, 255, 0), 3)
                # cv2.imwrite(base_path + '/' + f'{good_name}_{i}.jpg', region_mask, [cv2.IMWRITE_JPEG_QUALITY, 100])
            return draw_img,region_mask,good_names
        return draw_img,region,['没有商品']
def cv2ImgAddText(img, text, left, top, textColor=(0, 255, 0)):
    """Draw (possibly Chinese) text on an OpenCV BGR image via PIL.

    Accepts either a BGR ndarray or a PIL image; returns a BGR ndarray.
    Uses the module-level FONT so CJK glyphs render correctly.
    """
    # OpenCV ndarray -> PIL RGB image when needed
    if isinstance(img, np.ndarray):
        img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    # render the text at (left, top) with the preconfigured font
    ImageDraw.Draw(img).text((left, top), text, textColor, font=FONT)
    # back to OpenCV BGR format
    return cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
def trans_square(image):
    """Pad a PIL image with black to a centered square of its longer side."""
    rgb = np.array(image.convert('RGB'), dtype=np.uint8)
    img_h, img_w, img_c = rgb.shape
    if img_h == img_w:
        # already square: nothing to pad
        return Image.fromarray(rgb, 'RGB')
    side = max(img_w, img_h)
    short = min(img_w, img_h)
    # centering offset along the short axis
    offset = abs(img_w - img_h) // 2
    canvas = np.zeros((side, side, img_c), dtype=np.uint8)
    if img_h > img_w:
        # tall image: center horizontally
        canvas[:, offset:offset + short, :] = rgb
    else:
        # wide image: center vertically
        canvas[offset:offset + short, :, :] = rgb
    return Image.fromarray(canvas, 'RGB')
def get_log_img():
    """Register every image under data/log_org and save its cropped region.

    For each <class_dir>/<image> under base_path, the detector is run in
    registration mode (log=True) with the class directory name as the goods
    name, and the cropped min-rect region is written back next to the source
    image as '<good_name>_<stem>.jpg'.
    """
    gd = GoodDetect()
    base_path = 'D:/zhrdpy_project/AutShop/data/log_org'
    img_paths = glob.glob(os.path.join(base_path, "*", "*"))
    for img_path in img_paths:
        # FIX: os.path handles both '/' and '\\' separators, unlike the
        # original rsplit('\\') parsing, which breaks on POSIX-style paths.
        save_path, img_name = os.path.split(img_path)
        save_name = os.path.splitext(img_name)[0]
        good_name = os.path.basename(save_path)  # class folder == goods name
        img = cv2.imread(img_path)
        if img is None:
            # Unreadable / non-image file: skip instead of crashing downstream.
            continue
        draw_img, region, good_names = gd.detect_img(img, log=True, log_name=good_name)
        cv2.imwrite(save_path + '/' + f'{good_name}_{save_name}.jpg', region, [cv2.IMWRITE_JPEG_QUALITY, 100])
        # os.remove(img_path)
if __name__ == '__main__':
    # get_log_img()
    # Quick manual check: run detection on one test image and show the result.
    detector = GoodDetect()
    base_path = 'D:/zhrdpy_project/AutShop/data/log_img_test/'
    img_name = '2.jpg'
    img_path = base_path + img_name
    frame = cv2.imread(img_path)
    annotated, crop, labels = detector.detect_img(frame, log=False)
    cv2.imshow("draw_img", annotated)
    cv2.imshow("region", crop)
    print(labels)
    cv2.waitKey(0)
shop_window.py
作用:程序的图形界面与接口
from tkinter import *
from tkinter import filedialog
from PIL import Image, ImageTk
from analyse_good import *
from good_cls_data import name_dict
import threading
lock = threading.Lock()
class Window_shop():
    """Tkinter front-end for the goods detection pipeline.

    Lets the user pick an image or video, run detection, or register a new
    goods class; results are rendered into two image labels and a text box.
    """

    def __init__(self):
        self.root = Tk()
        self.img_Label = Label(self.root)     # annotated input preview
        self.img_outLabel = Label(self.root)  # cropped detection region
        self.txt = Text(self.root)            # textual result area
        self.detect = GoodDetect()            # detection / recognition backend
        self.img = None      # current frame (BGR ndarray) once a file is chosen
        self.Type = None     # extension of the chosen file ('jpg'/'png'/'MOV'/'mp4')
        self.no_img = True   # guards the action buttons until a file is picked
        self.imgshow_width = 0   # NOTE(review): never read anywhere in this file
        self.imgshow_height = 0
        self.img_ratio = 0.5     # preview down-scale factor
        self.var_name = StringVar()  # selected file path (shown in the text box)
        # 'logname' / 'cnname' are sentinels meaning "name not entered yet"
        self.logname = 'logname'
        self.cnname = 'cnname'

    # choose the input file path
    def selectPath_file(self):
        """Open a file dialog and preview the chosen image (or, for videos,
        a placeholder frame) in the main image label."""
        path_ = filedialog.askopenfilename(filetypes=[("图片或视频", [".jpg",".png", ".MOV", ".mp4"])])
        self.var_name.set(path_)
        self.Type = self.var_name.get().rsplit('.', maxsplit=2)[-1]
        if self.Type == 'jpg' or self.Type == 'png':
            self.no_img = False
            self.img = cv2.imread(self.var_name.get())
            img = Image.fromarray(cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB))
            self.img_width = int(img.width*self.img_ratio)
            self.img_height = int(img.height*self.img_ratio)
            # FIX: Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the
            # same filter (ANTIALIAS had been an alias for it).
            img = img.resize((self.img_width, self.img_height), Image.LANCZOS)
            photo = ImageTk.PhotoImage(img)
            self.img_Label.config(image=photo)
            self.img_Label.image = photo  # keep a reference so Tk doesn't drop it
        if self.Type == "MOV" or self.Type == "mp4":
            self.no_img = False
            self.img = cv2.imread('data/vidio_show.jpg')
            img = Image.fromarray(cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB))
            # FIX: also record a display size here so detect_img_start /
            # login_start cannot hit an AttributeError on self.img_width.
            self.img_width = int(img.width*self.img_ratio)
            self.img_height = int(img.height*self.img_ratio)
            photo = ImageTk.PhotoImage(img)
            self.img_Label.config(image=photo)
            self.img_Label.image = photo
        self.txt.delete(1.0, END)  # clear previous text
        export = self.var_name.get()
        self.txt.insert(END, export)  # echo the chosen path

    def detect_img_start(self):
        """Run detection on the loaded image and show the annotated frame,
        the cropped region and the recognized class names."""
        self.txt.delete(1.0, END)  # clear previous text
        if self.no_img:
            print("请选择图片或视频")
            export = "请选择图片或视频"
            self.txt.insert(END, export)
            return 0
        draw_img, region, good_list = self.detect.detect_img(self.img, log=False)
        export = str(good_list)
        img_show = Image.fromarray(cv2.cvtColor(draw_img, cv2.COLOR_BGR2RGB))
        img_out = Image.fromarray(cv2.cvtColor(region, cv2.COLOR_BGR2RGB))
        # FIX: LANCZOS replaces the removed Image.ANTIALIAS alias.
        img_show = img_show.resize((self.img_width, self.img_height), Image.LANCZOS)
        photo = ImageTk.PhotoImage(img_show)
        self.img_Label.config(image=photo)
        self.img_Label.image = photo
        img_out = img_out.resize((320, 320), Image.LANCZOS)
        photo_out = ImageTk.PhotoImage(img_out)
        self.img_outLabel.config(image=photo_out)
        self.img_outLabel.image = photo_out
        self.txt.insert(END, export)  # append detection result

    def login_start(self):
        """Register the loaded image as a new goods class.

        Prompts for a registration / Chinese name if they have not been
        entered yet, then runs detection in log mode to store the feature.
        """
        self.txt.delete(1.0, END)  # clear previous text
        if self.no_img:
            print("请选择图片或视频")
            export = "请选择图片或视频"
            self.txt.insert(END, export)
            return 0
        with lock:  # the naming dialog writes self.logname/self.cnname via callback
            if self.logname == 'logname' or self.cnname == 'cnname':
                self.get_good_name()
            if self.logname != 'logname' and self.cnname != 'cnname':
                draw_img, region, good_list = self.detect.detect_img(self.img, log=True, log_name=self.logname)
                export = str(good_list)
                img_show = Image.fromarray(cv2.cvtColor(draw_img, cv2.COLOR_BGR2RGB))
                img_out = Image.fromarray(cv2.cvtColor(region, cv2.COLOR_BGR2RGB))
                # FIX: LANCZOS replaces the removed Image.ANTIALIAS alias.
                img_show = img_show.resize((self.img_width, self.img_height), Image.LANCZOS)
                photo = ImageTk.PhotoImage(img_show)
                self.img_Label.config(image=photo)
                self.img_Label.image = photo
                # img_out = img_out.resize((320, 320), Image.LANCZOS)
                photo_out = ImageTk.PhotoImage(img_out)
                self.img_outLabel.config(image=photo_out)
                self.img_outLabel.image = photo_out
                self.txt.insert(END, export)
                # reset the sentinels so the next registration asks again
                self.logname = 'logname'
                self.cnname = 'cnname'

    def log_mane(self, logname, cnname):
        """Naming-dialog callback: store the entered names and extend the
        (Chinese name -> registration name) mapping."""
        self.logname = logname
        self.cnname = cnname
        if cnname not in name_dict:
            name_dict[cnname] = logname

    def get_good_name(self):
        """Pop up a small dialog asking for registration / Chinese names."""
        # NOTE(review): a second Tk() root is fragile; Toplevel(self.root)
        # would be the conventional choice — kept as-is to preserve behavior.
        namewin = Tk()
        namewin.geometry("250x130")
        namewin.title("商品命名窗口")
        # two prompt labels
        logname = Label(namewin, text='注册名称', width=80)
        cnname = Label(namewin, text='中文名称', width=80)
        # two entry fields
        entlog = Entry(namewin, width=100)
        entcn = Entry(namewin, width=100)
        # confirm / close buttons
        name_ok = Button(namewin, text='确认', command=lambda: self.log_mane(entlog.get(), entcn.get()))
        namewin_quit = Button(namewin, text='关闭', command=lambda: namewin.destroy())
        # widget layout
        logname.place(x=20, y=10, width=80, height=20)
        cnname.place(x=20, y=40, width=80, height=20)
        entlog.place(x=120, y=10, width=80, height=20)
        entcn.place(x=120, y=40, width=80, height=20)
        name_ok.place(x=100, y=80, width=50, height=20)
        namewin_quit.place(x=170, y=80, width=50, height=20)

    def choose_imgorvidio(self):
        """Dispatch the "start detection" button: image -> single detection,
        video -> frame-by-frame loop in OpenCV windows (press 'q' to quit)."""
        if self.Type == 'jpg' or self.Type == 'png':
            self.detect_img_start()
        if self.Type == "MOV" or self.Type == "mp4":
            cap = cv2.VideoCapture(self.var_name.get())
            while cap.isOpened():
                retval, frame = cap.read()
                if not retval:
                    print('can not read frame')
                    break
                # per-frame detection (video path uses the tracking variant)
                draw_img, region, good_list = self.detect.detect_vidio(frame)
                cv2.imshow("draw_img", draw_img)
                cv2.imshow("region", region)
                print(good_list)
                key = cv2.waitKey(42)  # ~24 fps playback
                if key == ord('q'):
                    break
            # release resources
            cap.release()
            cv2.destroyAllWindows()

    def run(self):
        """Build the main window layout and enter the Tk event loop."""
        self.root.title('商品自动检测')
        self.root.geometry('1000x800')  # 'x' here is the lowercase letter, not '*'
        # title banner
        lb_top = Label(self.root, text='商品自动检测程序',
                       bg='#d3fbfb',
                       fg='red',
                       font=('华文新魏', 32),
                       width=20,
                       height=2,
                       relief=SUNKEN)
        lb_top.pack()
        # result text area
        self.txt.place(rely=0.8, relwidth=1, relheight=0.3)
        # action buttons
        btn2 = Button(self.root, text='开始检测', command=lambda: self.choose_imgorvidio()).place(relx=0.7, rely=0.14, relwidth=0.2, relheight=0.08)
        btn1 = Button(self.root, text='开始注册', command=lambda: self.login_start()).place(relx=0.4, rely=0.14, relwidth=0.2, relheight=0.08)
        btn0 = Button(self.root, text="路径选择", command=lambda: self.selectPath_file()).place(relx=0.1, rely=0.14, relwidth=0.2, relheight=0.08)
        # image labels
        self.img_Label.place(relx=0.05, rely=0.25, relwidth=0.65, relheight=0.5)
        self.img_outLabel.place(relx=0.72, rely=0.25, relwidth=0.23, relheight=0.5)
        self.root.mainloop()
if __name__ == '__main__':
    # Build the GUI and enter the Tk main loop.
    app = Window_shop()
    app.run()
test.py
作用:不包括在项目中,但可能用到的一些小方法
import os.path
import time
import glob
import os
import cv2
import torch
import good_net
def rename(img_folder, prefix='1000'):
    """Prepend *prefix* to the name of every file in *img_folder*.

    Args:
        img_folder: directory whose entries are renamed in place.
        prefix: string prepended to each file name. Defaults to '1000',
            matching the original hard-coded behavior.
    """
    # os.listdir returns a snapshot, so renaming during iteration is safe.
    for img_name in os.listdir(img_folder):
        src = os.path.join(img_folder, img_name)            # current path
        dst = os.path.join(img_folder, prefix + img_name)   # renamed path
        os.rename(src, dst)
def delete():
    """Remove every '.jpg' file found two directory levels below log_img."""
    base_path = 'D:/zhrdpy_project/AutShop/data/log_img'
    for img_path in glob.glob(os.path.join(base_path, "*", "*", "*")):
        file_name = img_path.rsplit('\\', maxsplit=1)[-1]
        # keep only files whose extension (text after the last dot) is 'jpg'
        if file_name.rsplit('.', maxsplit=1)[-1] == 'jpg':
            os.remove(img_path)
def png2jpg():
    """Re-encode every image in one hard-coded class folder as max-quality JPEG.

    Each source file is read with OpenCV, written next to itself with a
    '.jpg' extension, and the original file is then deleted.
    """
    org_img_paths = glob.glob(os.path.join(r"D:\zhrdpy_project\AutShop\data\CLSDataset_test\box\xiang_jiao_niu_nai_he_zhuang", "*"))
    for path in org_img_paths:
        img = cv2.imread(path)
        if img is None:
            # FIX: non-image / unreadable file — the original would crash in
            # cv2.imwrite(None, ...); skip it and leave the source untouched.
            continue
        # FIX: os.path splitting works with '/' and '\\', unlike rsplit('\\').
        targe_path, image_name = os.path.split(path)
        save_name = os.path.splitext(image_name)[0]
        cv2.imwrite(targe_path+'/'+save_name+'.jpg', img, [cv2.IMWRITE_JPEG_QUALITY, 100])
        os.remove(path)
def main():
    """Batch-rename one fixed local folder (developer convenience helper)."""
    # folder path: point this directly at the \train or \test directory
    target_folder = r'D:\zhrdpy_project\grayimg\label\lll'
    rename(target_folder)
def save_densnet():
    """Load the best GoodNet checkpoint and print its structure.

    The torch.save lines are intentionally left commented out; enable them
    to export the DenseNet backbone / feature head as separate weight files.
    """
    net = good_net.GoodNet()
    net.load_state_dict(torch.load("weight/best.pt"))  # load the best weights
    net.eval()
    print(net)
    # torch.save(net.sub_net.state_dict(), 'weight/sub_net.pt')          # DenseNet part
    # torch.save(net.feature_net.state_dict(), 'weight/feature_net.pt')  # feature_net part
# Script entry point: inspect / export the trained DenseNet checkpoint.
if __name__=="__main__":
    save_densnet()
6 结语
一个简单的商品检测项目,数据集读者可以自行拍摄,录制视频抽帧即可训练模型了。
有什么交流意见可以评论或者私信我。
这里放一个展示视频:
商品检测效果视频