CNN的小体验

news2024/12/28 19:04:06

用的pytorch。

训练代码cnn.py:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

# 定义超参数
num_epochs = 10
batch_size = 100
learning_rate = 0.001

# 数据预处理和加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# 定义卷积神经网络
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2)
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 32 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 初始化模型、损失函数和优化器
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 训练模型
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

# 测试模型
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f'Accuracy of the model on the 10000 test images: {100 * correct / total}%')

# 保存模型
torch.save(model.state_dict(), 'cnn.pth')

推断代码cnn2.py

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F

# 定义卷积神经网络(与之前的定义保持一致)
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2) # 第一个卷积层
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) # 池化层
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2) # 第二个卷积层
        self.fc1 = nn.Linear(32 * 8 * 8, 128) # 全连接层
        self.fc2 = nn.Linear(128, 10) # 输出层

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x))) # 通过第一个卷积层和池化层
        x = self.pool(F.relu(self.conv2(x))) # 通过第二个卷积层和池化层
        x = x.view(-1, 32 * 8 * 8) # 展平
        x = F.relu(self.fc1(x)) # 通过全连接层
        x = self.fc2(x) # 通过输出层
        return x

# 加载模型
model = SimpleCNN()
model.load_state_dict(torch.load('cnn.pth'))
model.eval() # 设置模型为评估模式

# 预处理图片
def preprocess_image(image_path):
    transform = transforms.Compose([
        transforms.Resize((32, 32)), # 调整图像大小到32x32
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    image = Image.open(image_path)
    image = transform(image)
    image = image.unsqueeze(0) # 增加批量维度
    return image

# 加载并预处理图片
image_path = 'test.jpg' # 替换为你要分析的图片路径
image = preprocess_image(image_path)

# 使用模型进行推理
with torch.no_grad():
    outputs = model(image)
    _, predicted = torch.max(outputs.data, 1)
    class_index = predicted.item()

# CIFAR-10类别标签
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# 输出预测结果
print(f'Predicted class: {classes[class_index]}')

可惜出来的东西跟弱智一般。

python3 cnn2.py
Predicted class: horse

python3 cnn2.py
Predicted class: bird

几个小点:

1 使用的数据集是CIFAR10

2 训练真的挺耗时的,我用的阿里云,一共搞了差不多10分钟(训练一个弱智)。

3 环境依然麻烦,python,numpy的版本都不能太高。否则要出问题。。。

4 最后实事求是的说,我不太懂的一点是怎么分类出来的。。。晚点再看看。。。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1883297.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2024第17届中国西部(重庆)留学移民海外置业展览会

2024第17届中国西部(重庆)留学移民海外置业展览会 邀请函 主办单位: 中国西部教体医融合博览会组委会 承办单位:重庆中博展览有限公司 展会背景: 成都和重庆是中国新一线城市,是西部经济的核心增长极&a…

samba服务的搭建与使用

关闭selinux #暂时关闭selinux 查看selinux状态 [rootlocalhost ~]# getenforce Disabled [rootlocalhost ~]# 如果此处是‘enforcing’,则执行下列代码 [rootlocalhost ~]# setenforce 0 再次查看selinux状态 [rootlocalhost ~]# getenforce permissive #永久关…

舞会无领导:一种树形动态规划的视角

没有上司的舞会 Ural 大学有 𝑁 名职员,编号为1∼𝑁。 他们的关系就像一棵以校长为根的树,父节点就是子节点的直接上司。 每个职员有一个快乐指数,用整数 𝐻𝑖 给出,其中1≤&…

【Llama 2的使用方法】

Llama 2是Meta AI(Facebook的母公司Meta的AI部门)开发并开源的大型语言模型系列之一。Llama 2是在其前身Llama模型的基础上进行改进和扩展的,旨在提供更强大的自然语言处理能力和更广泛的应用场景。 以下是Llama 2的一些关键特性和更新点&am…

1Python的Pandas:基本简介

1. Pandas的简介 Pandas 是一个开源的 Python 数据分析库,由 Wes McKinney 在 2008 年开始开发,目的是为了解决数据分析任务中的各种需求。Pandas 是基于 NumPy 库构建的,它使得数据处理和分析工作变得更加快速和简单。Pandas 提供了易于使用…

mac|浏览器链接不上服务器但可以登微信

千万千万千万不要没有关梯子直接关机,不然就会这样子呜呜呜 设置-网络,点击三个点--选择--位置--编辑位置(默认是自动) 新增一个,然后选中点击完成 这样就可以正常上网了

网络编程:UDP编程笔记

1.字节序的概念和转换 小端格式: 低位字节数据存储在低地址 大端格式: 高位字节数据存储在低地址 在主机上时为小端存储,在网络上时为大端,所以接收到数据时,要转为小端口 如下图: #include <arpa/inet.h> 发送者调用的函数: uint32_t htonl(uint32_t hostlong); //转…

【工具推荐】ONLYOFFICE8.1版本编辑器测评——时下的办公利器

文章目录 一、产品介绍1. ONLYOFFICE 8.1简介2. 多元化多功能的编辑器 二、产品体验1. 云端协作空间2. 桌面编辑器本地版 三、产品界面设计1. 本地版本2. 云端版本 四、产品文档处理1. 文本文档&#xff08;Word)2. 电子表格&#xff08;Excel&#xff09;3. PDF表单&#xff0…

Linux——移动文件或目录,查找文件,which命令

移动文件或目录 作用 - mv命令用于剪切或重命名文件 格式 bash mv [选项] 源文件名称 目标文件名称 注意 - 剪切操作不同于复制操作&#xff0c;因为它会把源文件删除掉&#xff0c;只保留剪切后的文件。 - 如果在同一个目录中将某个文件剪切后还粘贴到当前目录下&#xff0c;…

芒果YOLOv10改进122:注意力机制系列:最新结合即插即用CA(Coordinate attention) 注意力机制,CVPR 顶会助力分类检测涨点!

论文所提的Coordinate注意力很简单,可以灵活地插入到经典的移动网络中,而且几乎没有计算开销。大量实验表明,Coordinate注意力不仅有益于ImageNet分类,而且更有趣的是,它在下游任务(如目标检测和语义分割)中表现也很好。本文结合目标检测任务应用 应专栏读者的要求,写一…

Jasper studio报表工具中,如何判断subDataSource()子报表数据源是否为空

目录 1.1、错误描述 1.2、解决方案 1.1、错误描述 今天在处理一个有关Jasper Studio报表模板制作的线上问题&#xff0c;需要根据某个报表子数据源是否为空&#xff0c;来决定对应的组件是否显示&#xff0c;找了好久的资料都没有实现&#xff0c;最后找到一种解决办法。就是…

MySQL架构和性能优化

文章目录 一、MySQL架构架构图存储引擎MyISAM引擎特点InnoDB引擎特点管理存储引擎 二、性能优化索引索引管理EXPLAIN 工具使用profile工具 监控 一、MySQL架构 架构图 存储引擎 MySQL提供了多种存储引擎供用户选择&#xff0c;每种存储引擎都有自己的特点和使用场景。 InnoDB…

【机器学习】FFmpeg+Whisper:二阶段法视频理解(video-to-text)大模型实战

目录 一、引言 二、FFmpeg工具介绍 2.1 什么是FFmpeg 2.2 FFmpeg核心原理 2.3 FFmpeg使用示例 三、FFmpegWhisper二阶段法视频理解实战 3.1 FFmpeg安装 3.2 Whisper模型下载 3.3 FFmpeg抽取视频的音频 3.3.1 方案一&#xff1a;命令行方式使用ffmpeg 3.3.2 方案二&a…

深入剖析Tomcat(十四) Server、Service 组件:如何启停Tomcat服务?

通过前面文章的学习&#xff0c;我们已经了解了连接器&#xff0c;四大容器是如何配合工作的&#xff0c;在源码中提供的示例也都是“一个连接器”“一个顶层容器”的结构。并且启动方式是分别启动连接器和容器&#xff0c;类似下面代码 connector.setContainer(engine); try …

MATLAB|更改绘图窗口的大小和位置

MATLAB绘图 plot、plot3、cdfplot都适用 效果 如下图&#xff0c;运行程序后可以直接得到这两个绘图窗口。 右上角的Figure1是原始图片&#xff0c;右下角的Figure2是调整了位置和大小后的绘图窗口。 完整源代码 % 绘图大小和位置调整 % Evand©2024 % 2024-7-1/Ver1…

代码随想录算法训练营第59天:动态[1]

代码随想录算法训练营第59天&#xff1a;动态 两个字符串的删除操作 力扣题目链接(opens new window) 给定两个单词 word1 和 word2&#xff0c;找到使得 word1 和 word2 相同所需的最小步数&#xff0c;每步可以删除任意一个字符串中的一个字符。 示例&#xff1a; 输入: …

MySQL 常见存储引擎详解(一)

本篇主要介绍MySQL中常见的存储引擎。 目录 一、InnoDB引擎 简介 特性 最佳实践 创建InnoDB 存储文件 二、MyISAM存储引擎 简介 特性 创建MyISAM表 存储文件 存储格式 静态格式 动态格式 压缩格式 三、MEMORY存储引擎 简介 特点 创建MEMORY表 存储文件 内…

【postgresql】版本学习

PostgreSQL 17 Beta 2 发布于2024-06-27。 PostgreSQL 17 Beta 2功能和变更功能的完整列表&#xff1a;PostgreSQL: Documentation: 17: E.1. Release 17 ​ 支持的版本&#xff1a; 16 ( 当前版本) / 15 / 14 / 13 / 12 ​ 不支持的版本&#xff1a; 11 / 10 / 9.6 / 9.5 /…

UE4_材质_材质节点_Fresnel

学习笔记&#xff0c;不喜勿喷&#xff0c;侵权立删&#xff0c;祝愿生活越来越好&#xff01; 一、问题导入 在创建电影或过场动画时&#xff0c;你常常需要想办法更好地突显角色或场景的轮廓。这时你需要用到一种光照技术&#xff0c;称为边沿光照或边缘光照&#xff0c;它…

Spring Cloud Circuit Breaker基础入门与服务熔断

官网地址&#xff1a;https://spring.io/projects/spring-cloud-circuitbreaker#overview 本文SpringCloud版本为&#xff1a; <spring.boot.version>3.1.7</spring.boot.version> <spring.cloud.version>2022.0.4</spring.cloud.version>【1】Circu…