空洞卷积网络实现

news2024/10/7 10:25:10

代码:





import torch.nn as nn
import numpy as np

from matplotlib import pyplot as plt
import time
#from utils import get_accur,load_data,train

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
import torch.optim as optim
import numpy as np
def load_data(path, batch_size):
    datasets = torchvision.datasets.ImageFolder(
        root = path,
        transform = transforms.Compose([
            transforms.ToTensor()
        ])
    )

    dataloder = DataLoader(datasets, batch_size=batch_size, shuffle=True)
    return datasets,dataloder

def get_accur(preds, labels):
    preds = preds.argmax(dim=1)
    return torch.sum(preds == labels).item()

def train(model, epochs, learning_rate, dataloader, criterion, testdataloader):
    optimizer = optim.Adam(model.parameters(),lr=learning_rate)

    train_loss_list = []
    test_loss_list = []
    train_accur_list = []
    test_accur_list = []
    train_len = len(dataloader.dataset)
    test_len = len(testdataloader.dataset)

    for i in range(epochs):
        train_loss = 0.0
        train_accur = 0
        test_loss = 0.0
        test_accur = 0
        for batch in dataloader:
            imgs, labels = batch
            preds = model(imgs)
            optimizer.zero_grad()
            loss = criterion(preds, labels)

            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_accur += get_accur(preds,labels)

        train_loss_list.append(train_loss)
        train_accur_list.append(train_accur / train_len)

        for batch in testdataloader:
            imgs, labels = batch
            preds = model(imgs)
            loss = criterion(preds, labels)
            test_loss += loss.item()
            test_accur += get_accur(preds,labels)

        test_loss_list.append(test_loss)
        test_accur_list.append(test_accur / test_len)

        print("epoch {} : train_loss : {}; train_accur : {}".format(i + 1, train_loss, train_accur / train_len))

    return np.array(train_accur_list), np.array(train_loss_list), np.array(test_accur_list), np.array(test_loss_list)

class ConvNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2, padding=0, dilation=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=0, dilation=2),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=0, dilation=5),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.fc = nn.Linear(128 * 3 * 3, 3)

    def forward(self, x):
        x = self.layer1(x)

        x = self.layer2(x)

        x = self.layer3(x)

        x = x.view(-1, 128 * 3 * 3)

        out = self.fc(x)

        return out


if __name__ == "__main__":
    train_path = "./cnn/train/"
    test_path = "./cnn/test/"
    train_datasets, train_dataloader = load_data(train_path, 64)
    test_datasets, test_dataloader = load_data(test_path, 64)
    model = ConvNetwork()
    critic = nn.CrossEntropyLoss()
    epoch = 15
    lr = 0.01
    start = time.clock()
    train_accur_list, train_loss_list, test_accur_list, test_loss_list = train(model, epoch, lr, train_dataloader,
                                                                               critic, test_dataloader)
    end = time.clock()
    test_accur = 0
    for batch in test_dataloader:
        imgs, labels = batch
        preds = model(imgs)
        test_accur += get_accur(preds, labels)

    print("Accuracy on test datasets : {}".format(test_accur / len(test_datasets)))
    print("Total time".format(end - start))
    x_axis = np.arange(1, epoch + 1)
    plt.plot(x_axis, train_loss_list, label="train loss")
    plt.plot(x_axis, test_loss_list, label="test loss")
    plt.legend()
    plt.show()

    plt.plot(x_axis, train_accur_list, label="train accur")
    plt.plot(x_axis, test_accur_list, label="test accur")
    plt.legend()
    plt.show()

执行结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/904204.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

硬编码基础一(经典定长指令,寄存器相关)

硬编码基础一(定长指令) push/pop 通用寄存器 50~57是push8个32位通用寄存器 58~5f是pop8个32位通用寄存器 inc/dec 通用寄存器 40~47是inc8个32位通用寄存器 47~4f是dec8个32位通用寄存器 八位通用寄存器的立即数赋值 b0~b3 {立即数} 是低八位(…

蓝蓝设计-专业UI设计公司-界面设计作品

数慧时空(全称:北京数慧时空信息技术有限公司)是以空间信息技术为核心、国内领先的国土资源行业解决方案提供商,具有业务咨询、自主软件研发、数据加工和信息服务的全业务综合服务能力,是国土资源部最重要的信息化合作…

AD19基础应用技巧(位号的调整)

在进行元件装配时,需要输出相应的装配文件,而元件的位号图可以方便比对元件装配。隐藏其他层,只显示Overlay和Solder层可以更方便地进行位号调整。 一般来说,位号大都放到相应元件旁边,其调整应遵循以下原则&#xff…

LLM低成本微调方法

LLM日益流行,已经渗透到各个领域,比如生物医学,但是模型的规模导致微调LLM对普通用户不够友好,因此,我们需要借助一些低成本方法,通过更新少量参数也达到与LLM全参数更新一样的效果。这里介绍三种主流方法&…

JVM——垃圾回收(垃圾回收算法+分代垃圾回收+垃圾回收器)

1.如何判断对象可以回收 1.1引用计数法 只要一个对象被其他对象所引用,就要让该对象的技术加1,某个对象不再引用其,则让它计数减1。当计数变为0时就可以作为垃圾被回收。 有一个弊端叫做循环引用,两个的引用计数都是1&#xff…

npm和node版本升级教程

cmd中查看本地安装的node版本 node -v //查询node的位置 where node2.官网下载所需要的node版本,安装在刚查出来的文件夹下,即覆盖掉原来的版本 3.查看node版本是否已经更新 4.查看npm版本是否和node版本相匹配 cnpm install -g npm

超实用的两款截图工具(FastStone Capture 和 Snipaste)

文章目录 一、概述1)FastStone Capture2)Snipaste 二、FastStone Capture 和 Snipaste 截图软件安装 一、概述 "FastStone Capture" 和 "Snipaste" 都是计算机上常用的截图工具,用于捕捉屏幕截图、编辑图像以及进行屏幕…

湘潭大学 湘大 XTU OJ 1217 A+B VII 题解(非常详细)

链接 1217 题目 题目描述 小明非常高兴你能帮他处理那些罗马数字,他想学着自己写罗马数字,但是他不知道自己到底写对了没有。 请你帮他写个程序,能正确地将10进制数字转换成罗马数字,以便他能进行核对。 罗马数字是使用字母组…

Docker运行Nacos容器,过一会就报错`UnsatisfiedDependencyException`

Docker运行Nacos容器,过一会就报错UnsatisfiedDependencyException 问题背景: 最近要上线一个项目,由于要使用Nacos作为服务注册中心,为了方便,我就打算直接使用Docker部署Nacos,没想到Nacos启动没一会就嗝…

三、Kafka生产者

目录 3.1 生产者消息发送流程3.1.1 发送原理 3.2 异步发送 API3.3 同步发送数据3.4 生产者分区3.4.1 kafka分区的好处3.4.2 生产者发送消息的分区策略3.4.3 自定义分区器 3.5 生产者如何提高吞吐量3.6 数据可靠性 3.1 生产者消息发送流程 3.1.1 发送原理 3.2 异步发送 API 3…

SAP MM学习笔记26- SAP中 振替转记(转移过账)和 在库转送(库存转储)2- 品目Code振替转记 和 在库转送

SAP 中在库移动 不仅有入库(GR),出库(GI),也可以是单纯内部的转记或转送。 1,振替转记(转移过账) 2,在库转送(库存转储) 1&#xff…

代码部署到服务器

前言:相信看到这篇文章的小伙伴都或多或少有一些编程基础,懂得一些linux的基本命令了吧,本篇文章将带领大家服务器如何部署一个使用django框架开发的一个网站进行云服务器端的部署。 文章使用到的的工具 Python:一种编程语言&…

阿里云无影云电脑/云桌面收费价格表_使用申请方法

阿里云无影云电脑配置具体收费价格表,4核8G企业办公型云电脑可以免费使用3个月,无影云电脑地域不同价格不同,无影云电脑费用是由云桌面配置、云盘、互联网访问带宽、AD Connector 、桌面组共用桌面session 等费用组成,阿里云百科分…

Linux驱动开发(Day5)

思维导图: 不同设备号文件绑定:

springboot sl4j2 写入日志到mysql

问题描述 springboot初始化的时候,会先初始化日志然后再加载数据源如果用配置文件进行初始化,那么会出现数据源没有加载成功,导致空指针异常 报错排查如下: 搜索报错信息,OBjects.invoke is Null打断点发现。dataso…

简历本-专业在线简历制作下载网站 自带智能简历诊断

简历本是一个高效的在线简历制作与管理工具,为求职者提供专业简历模板,使用简历本5分钟就能制作一份优秀简历,可随时随地将简历下载为Word、PDF、图片格式文件,可在线发送或投递,不过使用需要注册登陆,提供…

【NLP】1、BERT | 双向 transformer 预训练语言模型

文章目录 一、背景二、方法 论文:BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding 出处:Google 一、背景 在 BERT 之前的语言模型如 GPT 都是单向的模型,但 BERT 认为虽然单向(从左到右预测…

【机器学习实战】朴素贝叶斯:过滤垃圾邮件

【机器学习实战】朴素贝叶斯:过滤垃圾邮件 0.收集数据 这里采用的数据集是《机器学习实战》提供的邮件文件,该文件有ham 和 spam 两个文件夹,每个文件夹中各有25条邮件,分别代表着 正常邮件 和 垃圾邮件。 这里需要注意的是需要…

Brain:背内侧前额叶/背侧前扣带皮层(dmPFC/dACC)的相关争议

摘要 背内侧前额叶皮层/背侧前扣带皮层(dmPFC/dACC)是一个功能存在诸多理论和争议的脑区。甚至其精确的解剖边界也饱受争议。在过去的几十年里,dmPFC/dACC与15种以上的认知过程相关联,这些过程有时看起来完全无关(例如,身体感知、认知冲突)。…

c++优先级队列的模拟实现代码

了解: 1.优先队列是一种容器适配器,根据严格的弱排序标准,它的第一个元素总是它所包含的元素中最大的。 2. 类似于堆,在堆中可以随时插入元素,并且只能检索最大堆元素(优先队列中位于顶部的元素)。 3. 优先队列被实现为…