pytorch训练和使用resnet

news2024/11/25 12:49:16

pytorch训练和使用resnet

使用 CIFAR-10数据集

训练 resnet

resnet-train.py

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim

# 在CIFAR-10数据集中
# 训练集:包含50000张图像,用于训练模型。
# 测试集:包含10000张图像,用于评估模型的性能。
TRAIN_SIZE=50000
TEST_SIZE=10000

# 批量大小
BATCH_SIZE=128

# 数据预处理
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 加载CIFAR-10数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 使用预训练的ResNet模型 , 不从默认url下载预训练的模型
model = torchvision.models.resnet18(weights=None)
# 从当前路径加载预训练权重
model_path = './model/resnet18-f37072fd.pth'
model.load_state_dict(torch.load(model_path))

# 修改最后一层以适应CIFAR-10的10个类别
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)

# 将模型移到GPU(如果有)
if torch.cuda.is_available() :
    print('Using GPU')
    device = torch.device("cuda:0")
else :
    print('Using CPU')
    device = torch.device("cpu")   

model = model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)

# 学习率调度器
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# 训练网络
num_epochs = 50

print('start Training')

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    #总迭代次数 = 训练集大小 / 批量大小 =  向上取整(TRAIN_SIZE=50000 / BATCH_SIZE=128) = 391 次循环
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # 梯度清零
        optimizer.zero_grad()

        # 前向传播 + 向后传播 + 优化
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # 打印统计信息
        running_loss += loss.item()
        if i % 100 == 99:    # 每100个小批量打印一次
            print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

    # 更新学习率
    scheduler.step()

print('Finished Training')

# 测试网络
model.eval()
correct = 0
total = 0
with torch.no_grad():
    # 总迭代次数 = 测试集 / 批量大小 向上取整(TEST_SIZE=10000/BATCH_SIZE=128) = 79 次循环
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy_test = 100 * correct / total
print(f'Accuracy of the network on the 10000 test images: {accuracy_test:.2f}%')

# [Epoch 50, Batch 300] loss: 0.142
# Finished Training
# Accuracy of the network on the 10000 test images: 84.53%


# 准确率>0.8保存模型
if(accuracy_test > 0.8):
    print("Accuracy  > 0.8 ,save model")
    model_path = './model/trained_resnet18_cifar10.pth'
    torch.save(model.state_dict(), model_path)
    print(f'Model saved to {model_path}')

使用训练后的 resnet

评估数据
1.jpeg :

请添加图片描述

2.jpeg:

请添加图片描述

restnet-eval.py

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from PIL import Image

# 模型路径
model_path = './model/trained_resnet18_cifar10.pth'

# 类别标签
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 数据预处理
transform = transforms.Compose([
    transforms.Resize((32, 32)),  # 调整图像大小为32x32
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化
])

# 加载预训练的ResNet模型
model = torchvision.models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)
model.load_state_dict(torch.load(model_path))
model.eval()  # 设置模型为评估模式

# 将模型移到GPU(如果有)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

def predict_image(image_path):
    # 加载并预处理图像
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0)  # 添加批次维度
    image = image.to(device)

    # 进行预测
    with torch.no_grad():
        outputs = model(image)
        _, predicted = torch.max(outputs.data, 1)

    # 输出预测结果
    predicted_class = classes[predicted.item()]
    print(f'Predicted class: {predicted_class}')

# img is in classes
predict_image('./data/1.jpeg')

# img is not in classes
predict_image('./data/2.jpeg')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2223514.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

城市极客,存内先锋-存内社区主理人招募令

在这个数据驱动的时代,存内计算正成为推动技术革新的核心力量。 我们,存内计算社区,正站在这场革命的前沿,现在,我们正式发出召集令,寻找那些渴望引领技术浪潮的城市站主理人! 地点&#xff1a…

使用LangChain进行LLM应用开发(1)——了解LangChain

【课程链接】https://www.ai360labs.com/playground/course/66813572135124992/detail 【适用人群】 入门学习Langchain的同学轻体验ChatOpenAI的同学,平台提供Api-key,应该是很小的token额度,仅供练习 LangChain是一个开源框架&#xff0c…

【机器学习基础】全连接层

1. 定义: 每一个节点都跟其后面所有的神经元相连两层之间所有神经元都有权重连接,通常全连接层在卷积神经网络尾部也就是跟传统的神经网络神经元的连接方式是一样的 2. 作用: 全连接层(fully connected layers,FC)在整个卷积神经网络中起到“分类器”的作用。如果说卷积层、…

使用 NCC 和 PKG 打包 Node.js 项目为可执行文件(Linux ,macOS,Windows)

🎬 江城开朗的豌豆:个人主页 🔥 个人专栏 :《 VUE 》 《 javaScript 》 📝 个人网站 :《 江城开朗的豌豆🫛 》 ⛺️ 生活的理想,就是为了理想的生活 ! 目录 📘 文章引言 步骤 1:…

linux:线程id及线程互斥

线程的tid不像进程&#xff0c;那不是他真正的id&#xff0c;也不是内核的lwp&#xff0c;而是由pthread库维护的一个唯一值 给用户提供的线程ID&#xff0c;不是内核中的lwp&#xff0c;而是pthread库维护的一个唯一值 库内部也要承担对线程的管理 #include<stdio.h>…

2024青龙面板京东教程

一、接上文《2024最新青龙面板安装教程》 腾讯云轻量服务器2核2G4M&#xff0c;只要79一年&#xff0c;可续费一次。 购买地址&#xff1a;https://curl.qcloud.com/LpLkvjq1 二、拉库 拉库前请打开青龙面板-配置文件 第18行 GithubProxyUrl“” 双引号中的内容清空 复制以…

代理 IP 对于鸿蒙开发者的意义与帮助

华为推出的鸿蒙操作系统以其独特的分布式架构和强大的性能&#xff0c;吸引了众多开发者的目光。而在鸿蒙开发的过程中&#xff0c;代理 IP 技术也发挥着重要的作用&#xff0c;为开发者带来了诸多意义与帮助。 一、提供更广泛的测试环境 对于鸿蒙开发者来说&#xff0c;确保应…

高效数据集成:聚水潭采购入库单与金蝶云星空

聚水潭采购入库单与金蝶云星空的高效数据集成案例分享 在企业日常运营中&#xff0c;采购入库单的数据处理和管理是至关重要的一环。为了实现聚水潭采购入库单到金蝶云星空的无缝对接&#xff0c;我们采用了轻易云数据集成平台&#xff0c;成功配置并运行了“聚水潭采购入库单…

钉钉录播抓取视频

爬取钉钉视频 免责声明 此脚本仅供学习参考&#xff0c;切勿违法使用下载他人资源进行售卖&#xff0c;本人不但任何责任! 仓库地址: GItee 源码仓库 执行顺序 poxyM3u8开启代理getM3u8url用于获取m3u8文件userAgent随机请求头downVideo|downVideoThreadTqdm单线程下载和…

荣耀MagicOS 9.0发布会及开发者大会丨一图读懂应用服务及商业合作分论坛

更多优质流量变现服务&#xff0c;可点击荣耀广告变现服务查看&#xff1b; 荣耀远航计划——应用市场【耀闪行动】全新上线&#xff0c;更多激励及资源扶持可点击荣耀应用市场耀闪行动查看。

Zookeeper实战 集群环境部署

1、概述 今天我们来学习一下Zookeeper集群相关的内容&#xff0c;本文主要的内容有集群环境的搭建&#xff0c;集群常见的问题和对应的解决方案。 2、集群环境搭建 2.1、准备工作 首先我们准备好安装包&#xff0c;创建好集群部署的路径。将解压后的安装文件复制三分。这里…

水轮发电机油压自动化控制系统解决方案介绍

在现代水电工程中&#xff0c;水轮机组油压自动化控制系统&#xff0c;不仅直接关系到水轮发电机组的安全稳定运行&#xff0c;还影响着整个水电站的生产效率和经济效益。 一、系统概述 国科JSF油压自动控制系统&#xff0c;适用于水轮发电机组调速器油压及主阀&#xff08;蝶…

【功能安全】 独立于环境的安全要素SEooC

目录 01 SEooC定义 02 SEooC开发步骤 03 SEooC开发示例 04 SEooC问答 01 SEooC定义 缩写: SEooC:Safety Element out of Context独立于环境的安全要素 SEooC出处:GB/T34590.10—2022,第9章节 SEooC与相关项什么关系? SEooC可以是系统、系统组合、子系统、软件组件、…

【Unity】游戏UI中添加粒子特效导致穿层问题的解决

这里介绍一下简易的ui系统中&#xff0c;添加粒子特效导致的穿层问题 首先是在ui界面中添加粒子特效预制体&#xff0c;这个时候&#xff0c;控制这个粒子显示层级的有两个方面 上图中&#xff0c;如果你的Sorting Layer ID的值&#xff08;Layer排序&#xff09;是大于当前C…

安康旅游指南:基于SpringBoot的网站开发实践

第一章 绪论 1.1 研究现状 时代的发展&#xff0c;我们迎来了数字化信息时代&#xff0c;它正在渐渐的改变着人们的工作、学习以及娱乐方式。计算机网络&#xff0c;Internet扮演着越来越重要的角色&#xff0c;人们已经离不开网络了&#xff0c;大量的图片、文字、视频冲击着我…

可视化大屏的C位放啥(04):园区鸟瞰图,方寸之间显示海量数据

在可视化大屏的 C 位放置园区鸟瞰图&#xff0c;可谓独具匠心。 鸟瞰图以宏观视角呈现整个园区风貌&#xff0c;让人一眼便对园区布局有清晰认知。同时&#xff0c;通过巧妙的设计&#xff0c;可在这方寸之间展示海量数据。 如园区内各企业的生产数据、能耗情况、人员流动等信…

【学习笔记】强化学习

李宏毅深度强化学习 笔记 课程主页&#xff1a;NTU-MLDS18 视频&#xff1a;youtube B站 参考资料&#xff1a; 作业代码参考 纯numpy实现非Deep的RL算法 OpenAI tutorial 文章目录 李宏毅深度强化学习 笔记1. Introduction2. Policy Gradient2.1 Origin Policy Gradient2.2…

Matlab 车牌识别技术

1.1设计内容及要求&#xff1a; 课题研究的主要内容是对数码相机拍摄的车牌&#xff0c;进行基于数字图像处理技术的车牌定位技术和车牌字符分割技术的研究与开发&#xff0c;涉及到图像预处理、车牌定位、倾斜校正、字符分割等方面的知识,总流程图如图1-1所示。 图1-1系统总…

【问题解决】Flink在linux上运行成功但是无法访问webUI界面

一&#xff0c;问题 在搭建Flink的时候&#xff0c;已经在linux服务器上运行了./start-cluster.sh&#xff0c; 而且日志显示已经成功了。 服务器上也没有开启防火墙 正常来说应该能通过ip:8081来访问(8081是Flink WebUI的默认端口)&#xff0c;但是访问的时候&#xff0c;显示…

Redis未授权访问及配合SSRF总结

Redis是一个开源的内存数据库&#xff0c;它用于存储数据&#xff0c;并提供高性能、可扩展性和丰富的数据结构支持。 Redis复现文章较全 Redisssrf漏洞利用探测内网 RedisInsight/RedisDesktopManager可视化连接工具 漏洞原理 &#xff08;1&#xff09;redis绑定在 0.0.…