【PyTorch】图像二分类项目-部署

news2024/9/21 19:08:27

【PyTorch】图像二分类项目

【PyTorch】图像二分类项目-部署

在独立于训练脚本的新脚本中部署用于推理的模型,需要构造一个模型类的对象,并将权重加载到模型中。操作流程为:定义模型--加载权重--在验证和测试数据集上部署模型。

import torch.nn as nn
import numpy as np
# 设置随机种子
np.random.seed(0)
import torch.nn as nn
import torch.nn.functional as F

# 定义一个函数,用于计算卷积层的输出形状
def findConv2dOutShape(H_in,W_in,conv,pool=2):
    # 获取卷积核的大小
    kernel_size=conv.kernel_size
    # 获取卷积的步长
    stride=conv.stride
    # 获取卷积的填充
    padding=conv.padding
    # 获取卷积的扩张
    dilation=conv.dilation

    # 计算卷积后的高度
    H_out=np.floor((H_in+2*padding[0]-dilation[0]*(kernel_size[0]-1)-1)/stride[0]+1)
    # 计算卷积后的宽度
    W_out=np.floor((W_in+2*padding[1]-dilation[1]*(kernel_size[1]-1)-1)/stride[1]+1)

    # 如果pool不为空
    if pool:
        # 将H_out除以pool
        H_out/=pool
        W_out/=pool
    # 返回H_out和W_out的整数形式
    return int(H_out),int(W_out)

class Net(nn.Module):
    def __init__(self, params):
        super(Net, self).__init__()
    
        # 获取输入形状
        C_in,H_in,W_in=params["input_shape"]
        # 获取初始滤波器数量
        init_f=params["initial_filters"] 
        # 获取第一个全连接层神经元数量
        num_fc1=params["num_fc1"]  
        # 获取类别数量
        num_classes=params["num_classes"] 
        # 获取模型的dropout率,是0到1间的浮点数
        # Dropout是一种正则化技术,随机关闭部分神经元(输出设为0),防止过拟合,提高泛化能力
        self.dropout_rate=params["dropout_rate"] 
        
        # 定义第一个卷积层
        self.conv1 = nn.Conv2d(C_in, init_f, kernel_size=3)
        # 计算第一个卷积层的输出形状
        h,w=findConv2dOutShape(H_in,W_in,self.conv1)
        
        self.conv2 = nn.Conv2d(init_f, 2*init_f, kernel_size=3)
        h,w=findConv2dOutShape(h,w,self.conv2)
        self.conv3 = nn.Conv2d(2*init_f, 4*init_f, kernel_size=3)
        h,w=findConv2dOutShape(h,w,self.conv3)
        self.conv4 = nn.Conv2d(4*init_f, 8*init_f, kernel_size=3)
        h,w=findConv2dOutShape(h,w,self.conv4)
        
        # 计算全连接层的输入形状
        self.num_flatten=h*w*8*init_f
        
        # 定义第一个全连接层
        self.fc1 = nn.Linear(self.num_flatten, num_fc1)
        
        self.fc2 = nn.Linear(num_fc1, num_classes)
    
    # 定义前向传播函数,接收输入x
    def forward(self, x):
        # 第一个卷积层
        x = F.relu(self.conv1(x))
        # 第一个池化层
        x = F.max_pool2d(x, 2, 2)
        
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)

        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2, 2)

        x = F.relu(self.conv4(x))
        x = F.max_pool2d(x, 2, 2)
        
        # 将卷积层的输出展平
        x = x.view(-1, self.num_flatten)
        
        # 第一个全连接层
        x = F.relu(self.fc1(x))
        # Dropout层
        x=F.dropout(x, self.dropout_rate)
        # 第二个全连接层
        x = self.fc2(x)

        # 返回输入x应用对数软最大变换后的输出
        # log-softmax对数软最大值函数,常用于计算交叉熵损失函数(cross-entropy loss),因为交叉熵损失函数需要计算概率的对数。
        # dim参数指定了在哪个维度上应用log-softmax。例如,如果dim=1,则对每一行应用log-softmax。
        return F.log_softmax(x, dim=1)


# 定义模型参数
params_model={
    # 输入形状
    "input_shape": (3,96,96),
    # 初始过滤器数量
    "initial_filters": 8, 
    # 全连接层1的神经元数量
    "num_fc1": 100,
    # Dropout率
    "dropout_rate": 0.25,
    # 类别数量
    "num_classes": 2,
}

# 创建一个CNN模型,参数为params_model
cnn_model = Net(params_model)

import torch
# 权重文件路径
path2weights="./models/weights.pt"

# 加载权重文件
cnn_model.load_state_dict(torch.load(path2weights))

# 进入评估模式
cnn_model.eval()

# 移动模型至cuda设备
if torch.cuda.is_available():
    device = torch.device("cuda")
    cnn_model=cnn_model.to(device) 
import time 

# 定义一个函数,用于部署模型
def deploy_model(model,dataset,device, num_classes=2,sanity_check=False):
    # num_classes:类别数,默认为2
    # sanity_check:是否进行完整性检查,默认为False
    pass
    # 获取数据集长度
    len_data=len(dataset)
    
    # 初始化输出张量
    y_out=torch.zeros(len_data,num_classes)
    
    # 初始化真实标签张量
    y_gt=np.zeros((len_data),dtype="uint8")
    
    # 将模型移动到指定设备
    model=model.to(device)
    
    # 存储每次推理的时间
    elapsed_times=[]
    
    # 在不计算梯度的情况下执行以下代码
    with torch.no_grad():
        for i in range(len_data):
            # 获取数据集中的一个样本
            x,y=dataset[i]
            # 将真实标签存储到张量中
            y_gt[i]=y
            # 记录开始时间
            start=time.time()    
            # 进行推理
            y_out[i]=model(x.unsqueeze(0).to(device))
            # 计算推理时间
            elapsed=time.time()-start
            # 将推理时间存储到列表中
            elapsed_times.append(elapsed)

            # 如果进行完整性检查,则只进行一次推理
            if sanity_check is True:
                break

    # 计算平均推理时间
    inference_time=np.mean(elapsed_times)*1000
    # 打印平均推理时间
    print("average inference time per image on %s: %.2f ms " %(device,inference_time))
    # 返回推理结果和真实标签
    return y_out.numpy(),y_gt

import torch
from PIL import Image
from torch.utils.data import Dataset
import pandas as pd
import torchvision.transforms as transforms
import os

# 设置随机种子,使得每次运行代码时生成的随机数相同
torch.manual_seed(0)

class histoCancerDataset(Dataset):
    def __init__(self, data_dir, transform,data_type="train"):      
        # 获取数据目录
        path2data=os.path.join(data_dir,data_type)

        # 获取数据目录下的所有文件名
        self.filenames = os.listdir(path2data)

        # 获取数据目录下的所有文件的完整路径
        self.full_filenames = [os.path.join(path2data, f) for f in self.filenames]

        # 获取标签文件名
        csv_filename=data_type+"_labels.csv"
        # 获取标签文件的完整路径
        path2csvLabels=os.path.join(data_dir,csv_filename)
        # 读取标签文件
        labels_df=pd.read_csv(path2csvLabels)

        # 将标签文件的索引设置为文件名
        labels_df.set_index("id", inplace=True)

        # 获取每个文件的标签
        self.labels = [labels_df.loc[filename[:-4]].values[0] for filename in self.filenames]

        # 获取数据转换函数
        self.transform = transform
    
    def __len__(self):
        # 返回数据集的长度
        return len(self.full_filenames)
    
    def __getitem__(self, idx):
        # 根据索引获取图像
        image = Image.open(self.full_filenames[idx])  
        # 对图像进行转换
        image = self.transform(image)
        # 返回图像和标签
        return image, self.labels[idx]

import torchvision.transforms as transforms
# 创建一个数据转换器,将数据转换为张量
data_transformer = transforms.Compose([transforms.ToTensor()])

data_dir = "./data/"
# 传入数据目录、数据转换器和数据集类型
histo_dataset = histoCancerDataset(data_dir, data_transformer, "train")
# 打印数据集的长度
print(len(histo_dataset))

from torch.utils.data import random_split

# 获取数据集的长度
len_histo=len(histo_dataset)
# 训练集取数据集的80%
len_train=int(0.8*len_histo)
# 验证集取数据集的20%
len_val=len_histo-len_train

# 将数据集随机分割为训练集和验证集
train_ds,val_ds=random_split(histo_dataset,[len_train,len_val])

# 打印训练集和验证集的长度
print("train dataset length:", len(train_ds))
print("validation dataset length:", len(val_ds))

 

# 部署模型 
y_out,y_gt=deploy_model(cnn_model,val_ds,device=device,sanity_check=False)
# 打印输出和真实值的形状
print(y_out.shape,y_gt.shape)

使用预测输出计算模型在验证数据集上的精度

from sklearn.metrics import accuracy_score

# 获取预测
y_pred = np.argmax(y_out,axis=1)
print(y_pred.shape,y_gt.shape)

# 计算精度 
acc=accuracy_score(y_pred,y_gt)
print("accuracy: %.2f" %acc)

 

# 部署在CPU上
device_cpu = torch.device("cpu")
y_out,y_gt=deploy_model(cnn_model,val_ds,device=device_cpu,sanity_check=False)
print(y_out.shape,y_gt.shape)

复制data文件夹中的sample_submission.csv文件并命名为test_labels.csv

path2csv="./data/test_labels.csv"
# 读取csv文件,并存储到DataFrame中
labels_df=pd.read_csv(path2csv)
# 显示DataFrame的前几行
labels_df.head()

data_dir = "./data/"
# 创建测试数据集
histo_test = histoCancerDataset(data_dir, data_transformer,data_type="test")
# 打印测试数据集的长度
print(len(histo_test))

 

# 用测试数据集部署
y_test_out,_=deploy_model(cnn_model,histo_test, device, sanity_check=False)

# 使用np.argmax函数对y_test_out进行操作,得到y_test_pred
y_test_pred=np.argmax(y_test_out,axis=1)

# 打印y_test_pred的形状
print(y_test_pred.shape)

from torchvision import utils

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
np.random.seed(0)

# 定义一个函数,用于显示图像和标签
def show(img,y,color=True):
    # 将图像转换为numpy数组
    npimg = img.numpy()

    # 将图像的维度从(C,H,W)转换为(H,W,C)
    npimg_tr=np.transpose(npimg, (1,2,0))
    
    # 如果color为False,则将图像转换为灰度图像
    if color==False:
        npimg_tr=npimg_tr[:,:,0]
        plt.imshow(npimg_tr,interpolation='nearest',cmap="gray")
    else:
        # 否则,直接显示图像
        plt.imshow(npimg_tr,interpolation='nearest')
    # 显示图像的标签
    plt.title("label: "+str(y))
    
# 定义一个网格大小
grid_size=4
# 随机选择grid_size个图像的索引
rnd_inds=np.random.randint(0,len(histo_test),grid_size)
print("image indices:",rnd_inds)

# 从histo_test中获取grid_size个图像
x_grid_test=[histo_test[i][0] for i in range(grid_size)]
# 从y_test_pred中获取grid_size个标签
y_grid_test=[y_test_pred[i] for i in range(grid_size)]

# 将grid_size个图像组合成一个网格
x_grid_test=utils.make_grid(x_grid_test, nrow=4, padding=2)
print(x_grid_test.shape)

# 设置图像的大小
plt.rcParams['figure.figsize'] = (10.0, 5)
# 显示图像和标签
show(x_grid_test,y_grid_test)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1939002.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Windows11 安装Docker,安装至D盘(其他非C盘皆可)

Docker默认安装在C盘,这未来随着docker使用必定会导致C盘空间吃紧。 所以本文提前进行空间布局,将docker默认安装路径软链接到D盘。 软链接D盘 Docker默认安装路径为C:\Program Files\Docker。使用管理员权限打开命令终端 输入以下命令:m…

【LeetCode】day14:226 - 翻转二叉树, 101 - 对称二叉树, 104 - 二叉树的最大深度, 111 - 二叉树的最小深度

LeetCode 代码随想录跟练 Day14 226.翻转二叉树101.对称二叉树104.二叉树的最大深度111.二叉树的最小深度 226.翻转二叉树 题目描述: 给你一棵二叉树的根节点 root ,翻转这棵二叉树,并返回其根节点。 使用递归处理(迭代以及层序同…

jdk1.8 List集合Stream流式处理

jdk1.8 List集合Stream流式处理 一、介绍(为什么需要流Stream,能解决什么问题?)1.1 什么是 Stream?1.2 常见的创建Stream方法1.3 常见的中间操作1.4 常见的终端操作 二、创建流Stream2.1 Collection的.stream()方法2.2 数组创建流2.3 静态工厂…

单链表的创建与遍历--C

基本结构声明 struct node{int data; //数据域struct node *next;//指针域 }; #include<stdio.h> #include<stdlib.h>struct node{//链表结点 int data;//数据域 struct node *next;//指针域 }; typedef struct node Node; int main(void){Node *head,*p,*…

【高数复盘】武忠祥高数辅导讲义+严选题错题一轮复盘

第一章 函数、极限和连续 高等数学辅导讲义 1. 复盘&#xff1a;(xsinxcosx)(x-sixcosx)&#xff0c;前者可以带入cosx1&#xff0c;而后者不能带入&#xff0c;为何&#xff1f; 2. 复盘&#xff1a; 这道题很明显an≤1&#xff0c;对于直接求极限&#xff0c;可以考虑夹逼…

华为防火墙总部与分支机构建立IPsec VPN涉及NAT穿越

一、IPsec VPN基本概念 1、隧道建立方式&#xff1a;分为手动建立和IKE自动协商&#xff0c;手动建立需要人为配置指定所有IPsec建立的所有参数信息&#xff0c;不支持为动态地址的发起方&#xff0c;实际网络中很少应用&#xff1b;IKE协议是基于密钥管理协议ISAKMP框架设计而…

linux系统设置开机启动的两种方法systemd及rc.local(手工写sh脚本,手工写service服务)

文章目录 知识点实验一、systemd&#xff08;一&#xff09;自写一个sh脚本并加入开机启动&#xff08;二&#xff09;源码安装的nginx加入开机启动 rc.local 知识点 在Linux系统中&#xff0c;有多种方法可以设置开机启动。以下是其中的一些主要方法&#xff1a; systemd 在较…

本地部署 mistralai/Mistral-Nemo-Instruct-2407

本地部署 mistralai/Mistral-Nemo-Instruct-2407 1. 创建虚拟环境2. 安装 fschat3. 安装 transformers4. 安装 flash-attn5. 安装 pytorch6. 启动 controller7. 启动 mistralai/Mistral-Nemo-Instruct-24078. 启动 api9. 访问 mistralai/Mistral-Nemo-Instruct-2407 1. 创建虚拟…

[图解]《分析模式》漫谈16-“我用的”不能变成“我的”

1 00:00:00,720 --> 00:00:02,160 今天&#xff0c;我们来说一下 2 00:00:02,170 --> 00:00:04,850 “我用的”不能变成“我的” 3 00:00:04,860 --> 00:00:11,390 《分析模式》的前言 4 00:00:12,260 --> 00:00:13,410 有这么一句话 5 00:00:14,840 --> 0…

postman接口测试实战篇

击杀小游戏接口测试 接口测试简单介绍击杀小游戏代码下载单接口测试(postman)接口关联并参数化接口测试简单介绍 首先思考两个问题:1.接口是什么?2.接口测试是什么? 1.我们总是把接口想的很复杂,其实呢,它就是一个有特定输入和输出参数的交互逻辑处理单元,它不需要知…

学并发编程前需要明确的一些基础知识

线程和进程的区别 在计算机科学中&#xff0c;线程和进程是两个非常重要的概念。虽然它们常常被一起提到&#xff0c;但它们实际上有很大的不同。作为一个开发者&#xff0c;我在日常工作中经常需要区分这两者&#xff0c;以便更好地进行资源管理和优化。 进程与线程的基本定…

如何解决微服务下引起的 分布式事务问题

一、什么是分布式事务&#xff1f; 虽然叫分布式事务&#xff0c;但不是一定是分布式部署的服务之间才会产生分布式事务。不是在同一个服务或同一个数据库架构下&#xff0c;产生的事务&#xff0c;也就是分布式事务。 跨数据源的分布式事务 跨服务的分布式事务 二、解决方…

华为机试HJ60查找组成一个偶数最接近的两个素数

华为机试HJ60查找组成一个偶数最接近的两个素数 题目&#xff1a; 想法&#xff1a; 构建一个判断是否为素数的函数&#xff0c;再构建一个函数输出构成输入数值相差最小的两个素数。为了保证两个素数相差最小&#xff0c;从输入数值的二分之一处开始判断&#xff0c;遍历得到…

用Python写一个视频采集脚本,对某网站进行批量采集

最近某牙上又出现一批高质量视频&#xff0c;听说删的很快&#xff0c;还好我会Python&#xff0c;赶紧采集下来保存&#xff01; 准备工作 环境使用 Python 3.10 解释器 Pycharm 编辑器 模块使用 requests >>> 数据请求模块 re <正则表达式模块> os <文…

HW行动在即,邮件系统该怎么防守?

1. 什么是HW行动&#xff1f; HW行动是一项由公安部牵头&#xff0c;旨在评估企事业单位网络安全防护能力的活动&#xff0c;是国家应对网络安全问题所做的重要布局之一。 具体实践中&#xff0c;公安部组织攻防红、蓝两队&#xff08;红队为攻击队&#xff0c;主要由“国家队…

【漏洞复现】Netgear WN604 downloadFile.php 信息泄露漏洞(CVE-2024-6646)

0x01 产品简介 NETGEAR WN604是一款由NETGEAR&#xff08;网件&#xff09;公司生产的无线接入器&#xff08;或无线路由器&#xff09;提供Wi-Fi保护协议&#xff08;WPA2-PSK, WPA-PSK&#xff09;&#xff0c;以及有线等效加密&#xff08;WEP&#xff09;64位、128位和152…

面向初学者和专家的 40 大机器学习问答(2024 年更新)

面向初学者和专家的 40 大机器学习问答(2024 年更新) 一、介绍 机器学习是人工智能的重要组成部分,目前是数据科学中最受欢迎的技能之一。如果你是一名数据科学家,你需要擅长 python、SQL 和机器学习——没有两种方法。作为 DataFest 2017 的一部分,我们组织了各种技能测…

正则表达式(Ⅰ)——基本匹配

学习练习建议 正则表达式用途非常广泛&#xff0c;各种语言中都能见到它的身影&#xff08;js&#xff0c;java&#xff0c;mysql等&#xff09; 正则表达式可以快读校验/生成/替换符合要求的模式的字符串&#xff0c;而且语法通俗易懂&#xff0c;所以应用广泛 学习链接&am…

php随机海量高清壁纸系统源码,数据采集于网络,使用很方便

2022 多个分类随机海量高清壁纸系统源码&#xff0c;核心文件就两个&#xff0c;php文件负责采集&#xff0c;html负责显示&#xff0c;很简单。做流量工具还是不错的。 非第三方接口&#xff0c;图片数据采集壁纸多多官方所有数据&#xff01; 大家拿去自行研究哈&#xff0…

WEB前端09-前端服务器搭建(Node.js/nvm/npm)

前端服务器的搭建 在本文中&#xff0c;我们将介绍如何安装和配置 nvm&#xff08;Node Version Manager&#xff09;以方便切换不同版本的 Node.js&#xff0c;以及如何设置 npm&#xff08;Node Package Manager&#xff09;使用国内镜像&#xff0c;并搭建一个简单的前端服…