不废话了,直接上代码:
def load_imagepath_from_csv(csv_name):
image_path = []
with open(csv_name,'r') as file:
csv_reader = csv.reader(file)
next(csv_reader)
for row in csv_reader:
image_path.append(row[0])
return image_path
import csv
csv_name = "submission_demo.csv" #文件名仅供参考
image_path_list = load_imagepath_from_csv(csv_name)
image_path_list
from MMEdu import MMClassification as cls
import os
import csv
import numpy as np
rootpath = "test_image" #文件名仅供参考
csv_name = "submission_demo.csv" #文件名仅供参考
image_path_list = load_imagepath_from_csv(csv_name)
model = cls(backbone='MobileNet') #MobileNet也可以换成LeNet,ResNet18,ResNet50,RandForest等
checkpoint = 'checkpoints/cls_model/catsdogs_mobilenet_continue/best_accuracy_top-1_epoch_2.pth' #文件名仅供参考
predictions = []
for image_name in image_path_list:
image_path = rootpath+'/'+image_name
y_test_pred = model.inference(image=image_path, show=False, checkpoint=checkpoint, device='cpu')
y_test_pred = model.print_result(y_test_pred)
predictions.append(y_test_pred)
predictions
import csv
result_csv_path = 'inference_result/results1.csv'
with open(result_csv_path,"w",newline='') as csvfile:
csv_writer = csv.writer(csvfile)
csv_writer.writerow([f'filename','prediction','pre_class'])
for index.image_name in enumerate(image_path_list):
csv_writer.writerow([image_name,predictions[index][0]['标签'],predictions[index][0]['预测结果']])
csvfile.close()
import pandas as pd
df = pd.read_csv(result_csv_path, header=None)
df
下面是数据集(你也可以自己去网上搜图片):
这仅仅是MMEdu在图像分类上的一小部分作用,其他功能待大家发现!
本文内容为小编自己汇总,内容可能会有错误或疏漏,感谢大家的提议!
记得点赞和关注哦~