现在学AI的一个优势就是:前人栽树后人乘凉,很多资料都已完善,而且有很多很棒的开源作品可以学习,感谢大佬们
项目
项目源码地址
视频教程地址
我在大佬的基础上基于此模型还加上了根据特征值缓存进行快速识别的方法,以应对超市某些未能正确识别的场景,针对项目中window.py文件进行修改和补充:
- 缓存查询方法
def query_cache(self, image_features):
if not cache:
return None
max_similarity = -1
best_label = None
for image_id, (cached_features, label) in cache.items():
similarity = self.cosine_similarity(image_features, cached_features)
if similarity > max_similarity:
max_similarity = similarity
best_label = label
if max_similarity >= 0.5:
return best_label
else:
return None
- 余弦相邻计算
def cosine_similarity(self, features1, features2):
dot_product = np.dot(features1.flatten(), features2.flatten())
norm_features1 = np.linalg.norm(features1)
norm_features2 = np.linalg.norm(features2)
return dot_product / (norm_features1 * norm_features2)
- 缓存更新方法
def update_cache(self):
input_text = self.input_box.text()
self.label = input_text or self.label
self.result.setText(self.label)
# 如果缓存已满,移除最久未使用的条目
if len(cache) >= CACHE_CAPACITY:
cache.popitem(last=False)
# 添加新条目
cache[self.image_id] = (self.image_features, self.label)
self.input_box.clear()
self.class_names.append(self.label)
- 获取图片哈希值
def get_image_id_from_hash(self, img):
buffer = img.tobytes()
return hashlib.md5(buffer).hexdigest()
- 预测图片
def predict_img(self):
self.input_box.clear()
img = Image.open('images/target.png') # 读取图片
self.image_id = self.get_image_id_from_hash(img)
img = np.asarray(img) # 将图片转化为numpy的数组
start_time = time.time() # 记录开始时间
outputs = self.model.predict(img.reshape(1, 224, 224, 3), batch_size=1, ) # 将图片输入模型得到结果
end_time = time.time() # 记录结束时间
elapsed_time = end_time - start_time # 计算时间差
print("运行时间:", elapsed_time, "秒")
self.image_features = outputs
result = self.query_cache(outputs)
self.label = result
if result is None:
result_index = int(np.argmax(outputs))
result = self.class_names[result_index] # 获得对应的水果名称
self.result.setText(result)
self.label = result
else:
self.result.setText(result) # 在界面上做显示
- UI改造
def initUI(self):
main_widget = QWidget()
main_layout = QHBoxLayout()
font = QFont('楷体', 15)
# 主页面,设置组件并在组件放在布局上
left_widget = QWidget()
left_layout = QVBoxLayout()
img_title = QLabel("样本")
img_title.setFont(font)
img_title.setAlignment(Qt.AlignCenter)
self.img_label = QLabel()
img_init = cv2.imread(self.to_predict_name)
h, w, c = img_init.shape
scale = 400 / h
img_show = cv2.resize(img_init, (0, 0), fx=scale, fy=scale)
cv2.imwrite("images/show.png", img_show)
img_init = cv2.resize(img_init, (224, 224))
cv2.imwrite('images/target.png', img_init)
self.img_label.setPixmap(QPixmap("images/show.png"))
left_layout.addWidget(img_title)
left_layout.addWidget(self.img_label, 1, Qt.AlignCenter)
left_widget.setLayout(left_layout)
right_widget = QWidget()
right_layout = QVBoxLayout()
btn_change = QPushButton(" 上传图片 ")
btn_change.clicked.connect(self.change_img)
btn_change.setFont(font)
btn_predict = QPushButton(" 开始识别 ")
btn_predict.setFont(font)
btn_predict.clicked.connect(self.predict_img)
btn_update = QPushButton(" 更新缓存 ")
btn_update.setFont(font)
btn_update.clicked.connect(self.update_cache)
label_result = QLabel(' 果蔬名称 ')
self.result = QLabel("等待识别")
label_result.setFont(QFont('楷体', 16))
self.result.setFont(QFont('楷体', 24))
self.input_box = QLineEdit()
self.input_box.setPlaceholderText("请输入内容...")
right_layout.addStretch()
right_layout.addWidget(label_result, 0, Qt.AlignCenter)
right_layout.addStretch()
right_layout.addWidget(self.result, 0, Qt.AlignCenter)
right_layout.addStretch()
right_layout.addWidget(self.input_box)
right_layout.addStretch()
right_layout.addWidget(btn_change)
right_layout.addWidget(btn_predict)
right_layout.addWidget(btn_update)
right_layout.addStretch()
right_widget.setLayout(right_layout)
main_layout.addWidget(left_widget)
main_layout.addWidget(right_widget)
main_widget.setLayout(main_layout)
# 关于页面,设置组件并把组件放在布局上
label_super = QLabel("作者:cpa") # todo 更换作者信息
label_super.setFont(QFont('楷体', 12))
# label_super.setOpenExternalLinks(True)
label_super.setAlignment(Qt.AlignRight)
# 添加注释
self.addTab(main_widget, '主页')
self.setTabIcon(0, QIcon('images/主页面.png'))
运行:
待完善:
- 缓存部分目前市面上是需要将缓存值存入本地sqllite之类的数据库进行保存的,这样下次开机缓存数据不会丢失,这里只展示思路
- 缓存除了保存在本地外还可以上传云端进行增强学习然后下发最新模型在本地进行更新,形成完美闭环
打包
pip install pyinstaller
pyinstaller -F -w (-i icofile) filename
说明:
filename表示你的Python程序文件名
-w 表示隐藏程序运行时的命令行窗口(不加-w会有黑色窗口)
括号内的为可选参数,-i icofile表示给程序加上图标,图标必须为.ico格式
icofile表示图标的位置,建议直接放在程序文件夹里面,这样子打包的时候直接写文件名就好
pyinstaller -F -w -i 'test.ico' window.py
- 将图片文件和模型文件等window.py中用到的资源在打包后一起移入dist目录中,不然会资源找不到的错
- 发给你的小伙伴看看效果吧