使用pytorch利用神经网络原理进行图片的训练(持续学习中....)

news2025/1/18 7:37:48

1.做这件事的目的
语言只是工具,使用python训练图片数据,最终会得到.pth的训练文件,java有使用这个文件进行图片识别的工具,顺便整合,我觉得Neo4J正确率太低了,草莓都能识别成为苹果,而且速度慢,不能持续识别视频帧

2.什么是神经网络?(其实就是数学的排列组合最终得到统计结果的概率)

1.先把二维数组转为一维
2.通过公式得到节点个数和值
3…同2
4.通过节点得到概率(softmax归一化公式)
5.对比模型的和 差值=原始概率-目标结果概率
6.不断优化原来模型的概率
5.激活函数,激活某个节点的函数,可以引入非线性的(因为所有问题不可能是线性的比如 很少图片识别一定可以识别出绝对的正方形,他可能中间有一定弯曲或者线在中心短开了)

在这里插入图片描述
在这里插入图片描述

3.训练的代码
//环境python3.8 最好使用conda进行版本管理,不然每个版本都可能不兼容,到处碰壁

 #安装依赖
 pip install numpy torch torchvision matplotlib

#文件夹结构,图片一定要是28x28的
在这里插入图片描述

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
from torchvision.datasets.folder import ImageFolder

class Net(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(28 * 28, 64)
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.relu(self.fc3(x))
        x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
        return x

#导入数据
def get_data_loader(is_train):
     #张量,多维数组
    to_tensor = transforms.Compose([transforms.ToTensor()])
     # 下载数据集 下载目录
    data_set = MNIST("", is_train, transform=to_tensor, download=True)
     #一个批次15张,顺序打乱
    return DataLoader(data_set, batch_size=15, shuffle=True)

def get_image_loader(folder_path):
    to_tensor = transforms.Compose([transforms.ToTensor()])
    data_set = ImageFolder(folder_path, transform=to_tensor)
    return DataLoader(data_set, batch_size=1)

#评估准确率
def evaluate(test_data, net):
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        #按批次取数据
        for (x, y) in test_data:
            #计算神经网络预测值
            outputs = net.forward(x.view(-1, 28 * 28))

            for i, output in enumerate(outputs):
                #比较预测结果和测试集结果
                if torch.argmax(output) == y[i]:
                    #统计正确预测结果数
                    n_correct += 1
                #统计全部预测结果
                n_total += 1
        #返回准确率=正确/全部的
    return n_correct / n_total


def main():
    #加载训练集
    train_data = get_data_loader(is_train=True)
    #加载测试集
    test_data = get_data_loader(is_train=False)
    #初始化神经网络
    net = Net()
    #打印测试网络的准确率 0.1
    print("initial accuracy:", evaluate(test_data, net))
    #训练神经网络
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    #重复利用数据集 2次
    for epoch in range(100):
        for (x, y) in train_data:
            #初始化 固定写法
            net.zero_grad()
            #正向传播
            output = net.forward(x.view(-1, 28 * 28))
            #计算差值
            loss = torch.nn.functional.nll_loss(output, y)
            #反向误差传播
            loss.backward()
            #优化网络参数
            optimizer.step()
        print("epoch", epoch, "accuracy:", evaluate(test_data, net))
    # #使用3张图片进行预测
    # for (n, (x, _)) in enumerate(test_data):
    #     if n > 3:
    #         break
    #     predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28)))
    #     plt.figure(n)
    #     plt.imshow(x[0].view(28, 28))
    #     plt.title("prediction: " + str(int(predict)))
    # plt.show()
    image_loader = get_image_loader("aa")

    for (n, (x, _)) in enumerate(image_loader):
        if n > 2:
            break
        predict = torch.argmax(net.forward(x.view(-1, 28 * 28)))
        plt.figure(n)
        plt.imshow(x[0].permute(1, 2, 0))
        plt.title("prediction: " + str(int(predict)))
    plt.show()


if __name__ == "__main__":
    main()


#运行结果 弹框出现图片和识别结果

4.测试电脑的cuda是否安装成功,不成功不能运行下面的代码

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print('CUDA version:', torch.version.cuda)
print('PyTorch version:', torch.__version__)

5.在gpu上运行,需要去官网下载cuda安装
https://developer.nvidia.com/cuda-toolkit-archive
#并且需要安装和torch对应的版本,我的电脑是1660ti的所以安装了10.2的cuda
#安装torchgpu版本

pip install torch==1.9.0+cu102 -f
https://download.pytorch.org/whl/cu102/torch_stable.html

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
from torchvision.datasets.folder import ImageFolder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Net(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(28 * 28, 64)
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.relu(self.fc3(x))
        x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
        return x

def get_data_loader(is_train):
    to_tensor = transforms.Compose([transforms.ToTensor()])
    data_set = MNIST("", is_train, transform=to_tensor, download=True)
    return DataLoader(data_set, batch_size=15, shuffle=True)

def get_image_loader(folder_path):
    to_tensor = transforms.Compose([transforms.ToTensor()])
    data_set = ImageFolder(folder_path, transform=to_tensor)
    return DataLoader(data_set, batch_size=1)

def evaluate(test_data, net):
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for (x, y) in test_data:
            x, y = x.to(device), y.to(device)
            outputs = net.forward(x.view(-1, 28 * 28))
            for i, output in enumerate(outputs):
                if torch.argmax(output.cpu()) == y[i].cpu():
                    n_correct += 1
                n_total += 1
    return n_correct / n_total

def main():
    train_data = get_data_loader(is_train=True)
    test_data = get_data_loader(is_train=False)
    net = Net().to(device)
    print("initial accuracy:", evaluate(test_data, net))
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    for epoch in range(100):
        for (x, y) in train_data:
            x, y = x.to(device), y.to(device)
            net.zero_grad()
            output = net.forward(x.view(-1, 28 * 28))
            loss = torch.nn.functional.nll_loss(output, y)
            loss.backward()
            optimizer.step()
        print("epoch", epoch, "accuracy:", evaluate(test_data, net))
    image_loader = get_image_loader("aa")

    for (n, (x, _)) in enumerate(image_loader):
        if n > 2:
            break
        x = x.to(device)
        predict = torch.argmax(net.forward(x.view(-1, 28 * 28)).cpu())
        plt.figure(n)
        plt.imshow(x[0].permute(1, 2, 0).cpu())
        plt.title("prediction: " + str(int(predict)))
    plt.show()

if __name__ == "__main__":
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1234992.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue项目 配置项设置

一、项目运行时浏览器自动打开 找到package.json文件 找到"sctipts"配置项 在"serve"配置项最后加上--open "scripts": {"serve": "vue-cli-service serve --open","build": "vue-cli-service build&quo…

2023年【四川省安全员A证】复审考试及四川省安全员A证考试试题

题库来源:安全生产模拟考试一点通公众号小程序 四川省安全员A证复审考试根据新四川省安全员A证考试大纲要求,安全生产模拟考试一点通将四川省安全员A证模拟考试试题进行汇编,组成一套四川省安全员A证全真模拟考试试题,学员可通过…

Nacos介绍与使用

Nacos介绍与使用 文章目录 Nacos介绍与使用一. 什么是Nacos1 Nacos功能1.1 配置中心1.2 注册中心 2.为什么要使用Nacos 二.Nacos 部署安装1. Nacos 部署方式2. Nacos 安装3. 配置数据源4. 开启控制台授权登录(可选) 三. Nacos配置中心的使用1. 创建配置信…

2023/11/21JAVAweb学习

优先级高低id > 类 > 元素 格式化ctrl alt L

LeetCode热题100——动态规划

动态规划 1. 爬楼梯2. 杨辉三角3. 打家劫舍 1. 爬楼梯 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬 1 或 2 个台阶。你有多少种不同的方法可以爬到楼顶呢? // 题解:每次都有两种选择,1或者2 int climbStairs(int n) {if (n …

GAMES101—Lec 05~06:光栅化

目录 概念回顾(个人理解)光栅化1.采样2.采样出现的问题:走样 反走样 概念回顾(个人理解) 屏幕:在图形学中,我们认为屏幕是一个二维数组,数组里的每一个元素为一个二维像素。 光栅化…

【C++进阶之路】第四篇:set和map

文章目录 一、关联式容器健值对二、set & multiset三、map & multimap在这里插入图片描述 四、set和map底层原理 一、关联式容器健值对 关联式容器 & 键值对 二、set & multiset set & multiset 三、map & multimap map & multimap 四、set和…

【AT模式连接ONENET】ONENET可视化平台的使用

02 ONENET可视化平台的使用 ATCWMODE1 设置模式 ATCWDHCP1,1 启动DHCP功能 ①ATCWJAP"ssid","password" ATCWJAP“123456789”,“wang020118” ②ATMQTTUSERCFG0,1,"设备名字","设备ID","你的鉴权信息""…

JAVA项目测试----用户管理系统

一)项目简介: 用户管理系统是依据于前后端分离来实现的,是基于Spring SpringBoot Spring MVC,SpringAOP,MyBatis等框架来实现的一个用户管理网站,并且已经部署到了云服务器上, 目前的用户管理系统实现了超级管理员的注册功能&…

模电 01

一.半导体基本知识 1.优点:体积小、重量轻、使用寿命长、输入功率小、功率转换效率高。 2.性能介于导体与绝缘体 3.常用半导体材料:硅(SI) 镉(Ge),化合物半导体:砷化镓(GaAs&…

【封装UI组件库系列】全局样式的定义与重置

封装UI组件库系列第二篇样式​​​​​​​ ​​​​​​🌟前言 🌟定义全局样式 生成主题色和不同亮度的颜色 ​编辑 中性色及其他变量 🌟样式重置 🌟总结 ​​​​​​​​​​​​​​🌟前言 在前端开发中&…

SpringBoot趣探究--1.logo是如何打印出来的

一.前言 从本篇开始,我将对springboot框架做一个有趣的探究,探究一下它的流程,虽然源码看不懂,不过我们可以一点一点慢慢深挖,好了,下面我们来看一下本篇的知识,这个logo是如何打印出来的&#…

2014年2月24日 Go生态洞察:FOSDEM 2014上的Go演讲精选

🌷🍁 博主猫头虎(🐅🐾)带您 Go to New World✨🍁 🦄 博客首页——🐅🐾猫头虎的博客🎐 🐳 《面试题大全专栏》 🦕 文章图文…

【Echart】Echart设置label太长隐藏:

文章目录 第一种:竖排显示第二种:显示部分第三种:强制显示所有标签并旋转 第一种:竖排显示 xAxis: {type: category,data: res.data.data.sz.xAxis,axisLabel:{fontSize:12,formatter: function(value) {return value.split().joi…

gitlab安装以及创建用户创建组,修改密码 邮箱配置 数据备份与恢复--保姆级教学!

GitLab是一种基于Web的Git仓库管理工具,它允许您在组织或个人级别上创建和管理Git仓库,以便在一个中心位置上执行代码管理和协作工作。GitLab提供了强大的功能,如代码审查、问题跟踪、CI/CD、容器注册表、Wiki和持续集成等。 以下是GitLab的…

gitlab安装配置及应用

安装 ##安装依赖 yum install -y curl policycoreutils-python openssh-server perl#上传包 rz gitlab-jh-16.5.2-jh.0.el7.x86_64.rpm 安装 yum install gitlab-jh-16.0.3-jh.0.el7.x86_64.rpm 初始化并启动 # 以下两种方法都可以配置访问地址,第一种需要在yum安…

2023年【A特种设备相关管理(锅炉压力容器压力管道)】模拟考试题及A特种设备相关管理(锅炉压力容器压力管道)作业考试题库

题库来源:安全生产模拟考试一点通公众号小程序 A特种设备相关管理(锅炉压力容器压力管道)模拟考试题参考答案及A特种设备相关管理(锅炉压力容器压力管道)考试试题解析是安全生产模拟考试一点通题库老师及A特种设备相关…

[Docker]八.Docker 容器跨主机通讯

一.跨主机通讯原理 在主机192.168.31.140上的docker0(172.17.0.0/16)中有一个容器mycentos( 172.17.0.2/16), 在主机192.168.31.81上的docker0(172.17.0.0/16)中有一个容器mycentos( 172.17.0.2/16),然后在主机192.168.31.140上ping主机192.168.31.81,发现ping不通要实现两个主…