Pytorch深度学习—FashionMNIST数据集训练

news2024/12/28 3:47:51

文章目录

    • FashionMNIST数据集
    • 需求库导入、数据迭代器生成
    • 设备选择
    • 样例图片展示
    • 日志写入
    • 评估—计数器
    • 模型构建
    • 训练函数
    • 整体代码
    • 训练过程
    • 日志

FashionMNIST数据集

  • FashionMNIST(时尚 MNIST)是一个用于图像分类的数据集,旨在替代传统的手写数字MNIST数据集。它由 Zalando Research 创建,适用于深度学习和计算机视觉的实验。
    • FashionMNIST 包含 10 个类别,分别对应不同的时尚物品。这些类别包括 T恤/上衣、裤子、套头衫、裙子、外套、凉鞋、衬衫、运动鞋、包和踝靴。
    • 每个类别有 6,000 张训练图像和 1,000 张测试图像,总计 70,000 张图像。
    • 每张图像的尺寸为 28x28 像素,与MNIST数据集相同。
    • 数据集中的每个图像都是灰度图像,像素值在0到255之间。
      在这里插入图片描述

需求库导入、数据迭代器生成

import os
import random
import numpy as np
import datetime
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms

import argparse
from tqdm import tqdm

import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter


def _load_data():
    """download the data, and generate the dataloader"""
    trans = transforms.Compose([transforms.ToTensor()])

    train_dataset = torchvision.datasets.FashionMNIST(root='./data/', train=True, download=True, transform=trans)
    test_dataset = torchvision.datasets.FashionMNIST(root='./data/', train=False, download=True, transform=trans)
    # print(len(train_dataset), len(test_dataset))
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.num_works)
    test_loader = DataLoader(test_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.num_works)

    return (train_loader, test_loader)

设备选择

def _device():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return device

样例图片展示

"""display data examples"""
def _image_label(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                  'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]


def _show_images(imgs, rows, columns, titles=None, scale=1.5):
    figsize = (rows * scale, columns * 1.5)
    fig, axes = plt.subplots(rows, columns, figsize=figsize)
    axes = axes.flatten()
    for i, (img, ax) in enumerate(zip(imgs, axes)):
        ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    plt.show()
    return axes

def _show_examples():
    train_loader, test_loader = _load_data()

    for images, labels in train_loader:
        images = images.squeeze(1)
        _show_images(images, 3, 3, _image_label(labels))
        break

日志写入

class _logger():
    def __init__(self, log_dir, log_history=True):
        if log_history:
            log_dir = os.path.join(log_dir, datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S"))
        self.summary = SummaryWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        self.summary.add_scalars(tag, value, step)

    def images_summary(self, tag, image_tensor, step):
        self.summary.add_images(tag, image_tensor, step)

    def figure_summary(self, tag, figure, step):
        self.summary.add_figure(tag, figure, step)

    def graph_summary(self, model):
        self.summary.add_graph(model)

    def close(self):
        self.summary.close()

评估—计数器

class AverageMeter():
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

模型构建

class Conv3x3(nn.Module):
    def __init__(self, in_channels, out_channels, down_sample=False):
        super(Conv3x3, self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, 1, 1),
                                  nn.BatchNorm2d(out_channels),
                                  nn.ReLU(inplace=True),
                                  nn.Conv2d(out_channels, out_channels, 3, 1, 1),
                                  nn.BatchNorm2d(out_channels),
                                  nn.ReLU(inplace=True))
        if down_sample:
            self.conv[3] = nn.Conv2d(out_channels, out_channels, 2, 2, 0)

    def forward(self, x):
        return self.conv(x)

class SimpleNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SimpleNet, self).__init__()
        self.conv1 = Conv3x3(in_channels, 32)
        self.conv2 = Conv3x3(32, 64, down_sample=True)
        self.conv3 = Conv3x3(64, 128)
        self.conv4 = Conv3x3(128, 256, down_sample=True)
        self.fc = nn.Linear(256*7*7, out_channels)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)

        x = torch.flatten(x, 1)
        out = self.fc(x)
        return out

训练函数

def train(model, train_loader, test_loader, criterion, optimizor, epochs, device, writer, save_weight=False):
    train_loss = AverageMeter()
    test_loss = AverageMeter()
    train_precision = AverageMeter()
    test_precision = AverageMeter()

    time_tick = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")

    for epoch in range(epochs):
        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, args.lr))
        model.train()
        for input, label in tqdm(train_loader):
            input, label = input.to(device), label.to(device)
            output = model(input)
            # backward
            loss = criterion(output, label)
            optimizor.zero_grad()
            loss.backward()
            optimizor.step()

            # logger
            predict = torch.argmax(output, dim=1)
            train_pre = sum(predict == label) / len(label)
            train_loss.update(loss.item(), input.size(0))
            train_precision.update(train_pre.item(), input.size(0))

        model.eval()
        with torch.no_grad():
            for X, y in tqdm(test_loader):
                X, y = X.to(device), y.to(device)
                y_hat = model(X)

                loss_te = criterion(y_hat, y)
                predict_ = torch.argmax(y_hat, dim=1)
                test_pre = sum(predict_ == y) / len(y)

                test_loss.update(loss_te.item(), X.size(0))
                test_precision.update(test_pre.item(), X.size(0))

        if save_weight:
            best_dice = args.best_dice
            weight_dir = os.path.join(args.weight_dir, args.model, time_tick)
            os.makedirs(weight_dir, exist_ok=True)

            monitor_dice = test_precision.avg
            if monitor_dice > best_dice:
                best_dice = max(monitor_dice, best_dice)

                name = os.path.join(weight_dir, args.model + '_' + str(epoch) + \
                       '_test_loss-' + str(round(test_loss.avg, 4)) + \
                       '_test_dice-' + str(round(best_dice, 4)) + '.pt')
                torch.save(model.state_dict(), name)

        print("train" + '---Loss: {loss:.4f} | Dice: {dice:.4f}'.format(loss=train_loss.avg, dice=train_precision.avg))
        print("test " + '---Loss: {loss:.4f} | Dice: {dice:.4f}'.format(loss=test_loss.avg, dice=test_precision.avg))

        # summary
        writer.scalar_summary("Loss/loss", {"train": train_loss.avg, "test": test_loss.avg}, epoch)
        writer.scalar_summary("Loss/precision", {"train": train_precision.avg, "test": test_precision.avg}, epoch)

        writer.close()

整体代码

import os
import random
import numpy as np
import datetime
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms

import argparse
from tqdm import tqdm

import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

"""Reproduction experiment"""
def setup_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # torch.backends.cudnn.benchmark = False
    # torch.backends.cudnn.enabled = False
    # torch.backends.cudnn.deterministic = True


"""data related"""
def _base_options():
    parser = argparse.ArgumentParser(description="Train setting for FashionMNIST")
    # about dataset
    parser.add_argument('--batch_size', default=8, type=int, help='the batch size of dataset')
    parser.add_argument('--num_works', default=4, type=int, help="the num_works used")
    # train
    parser.add_argument('--epochs', default=100, type=int, help='train iterations')
    parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
    parser.add_argument('--model', default="SimpleNet", choices=["SimpleNet"], help="the model choosed")
    # log dir
    parser.add_argument('--log_dir', default="./logger/", help='the path of log file')
    #
    parser.add_argument('--best_dice', default=-100, type=int, help='for save weight')
    parser.add_argument('--weight_dir', default="./weight/", help='the dir for save weight')

    args = parser.parse_args()
    return args

def _load_data():
    """download the data, and generate the dataloader"""
    trans = transforms.Compose([transforms.ToTensor()])

    train_dataset = torchvision.datasets.FashionMNIST(root='./data/', train=True, download=True, transform=trans)
    test_dataset = torchvision.datasets.FashionMNIST(root='./data/', train=False, download=True, transform=trans)
    # print(len(train_dataset), len(test_dataset))
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.num_works)
    test_loader = DataLoader(test_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.num_works)

    return (train_loader, test_loader)

def _device():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return device

"""display data examples"""
def _image_label(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                  'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]


def _show_images(imgs, rows, columns, titles=None, scale=1.5):
    figsize = (rows * scale, columns * 1.5)
    fig, axes = plt.subplots(rows, columns, figsize=figsize)
    axes = axes.flatten()
    for i, (img, ax) in enumerate(zip(imgs, axes)):
        ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    plt.show()
    return axes

def _show_examples():
    train_loader, test_loader = _load_data()

    for images, labels in train_loader:
        images = images.squeeze(1)
        _show_images(images, 3, 3, _image_label(labels))
        break

"""log"""
class _logger():
    def __init__(self, log_dir, log_history=True):
        if log_history:
            log_dir = os.path.join(log_dir, datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S"))
        self.summary = SummaryWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        self.summary.add_scalars(tag, value, step)

    def images_summary(self, tag, image_tensor, step):
        self.summary.add_images(tag, image_tensor, step)

    def figure_summary(self, tag, figure, step):
        self.summary.add_figure(tag, figure, step)

    def graph_summary(self, model):
        self.summary.add_graph(model)

    def close(self):
        self.summary.close()

"""evaluate the result"""
class AverageMeter():
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


"""define the Net"""
class Conv3x3(nn.Module):
    def __init__(self, in_channels, out_channels, down_sample=False):
        super(Conv3x3, self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, 1, 1),
                                  nn.BatchNorm2d(out_channels),
                                  nn.ReLU(inplace=True),
                                  nn.Conv2d(out_channels, out_channels, 3, 1, 1),
                                  nn.BatchNorm2d(out_channels),
                                  nn.ReLU(inplace=True))
        if down_sample:
            self.conv[3] = nn.Conv2d(out_channels, out_channels, 2, 2, 0)

    def forward(self, x):
        return self.conv(x)

class SimpleNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SimpleNet, self).__init__()
        self.conv1 = Conv3x3(in_channels, 32)
        self.conv2 = Conv3x3(32, 64, down_sample=True)
        self.conv3 = Conv3x3(64, 128)
        self.conv4 = Conv3x3(128, 256, down_sample=True)
        self.fc = nn.Linear(256*7*7, out_channels)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)

        x = torch.flatten(x, 1)
        out = self.fc(x)
        return out

"""progress of train/test"""
def train(model, train_loader, test_loader, criterion, optimizor, epochs, device, writer, save_weight=False):
    train_loss = AverageMeter()
    test_loss = AverageMeter()
    train_precision = AverageMeter()
    test_precision = AverageMeter()

    time_tick = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")

    for epoch in range(epochs):
        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, args.lr))
        model.train()
        for input, label in tqdm(train_loader):
            input, label = input.to(device), label.to(device)
            output = model(input)
            # backward
            loss = criterion(output, label)
            optimizor.zero_grad()
            loss.backward()
            optimizor.step()

            # logger
            predict = torch.argmax(output, dim=1)
            train_pre = sum(predict == label) / len(label)
            train_loss.update(loss.item(), input.size(0))
            train_precision.update(train_pre.item(), input.size(0))

        model.eval()
        with torch.no_grad():
            for X, y in tqdm(test_loader):
                X, y = X.to(device), y.to(device)
                y_hat = model(X)

                loss_te = criterion(y_hat, y)
                predict_ = torch.argmax(y_hat, dim=1)
                test_pre = sum(predict_ == y) / len(y)

                test_loss.update(loss_te.item(), X.size(0))
                test_precision.update(test_pre.item(), X.size(0))

        if save_weight:
            best_dice = args.best_dice
            weight_dir = os.path.join(args.weight_dir, args.model, time_tick)
            os.makedirs(weight_dir, exist_ok=True)

            monitor_dice = test_precision.avg
            if monitor_dice > best_dice:
                best_dice = max(monitor_dice, best_dice)

                name = os.path.join(weight_dir, args.model + '_' + str(epoch) + \
                       '_test_loss-' + str(round(test_loss.avg, 4)) + \
                       '_test_dice-' + str(round(best_dice, 4)) + '.pt')
                torch.save(model.state_dict(), name)

        print("train" + '---Loss: {loss:.4f} | Dice: {dice:.4f}'.format(loss=train_loss.avg, dice=train_precision.avg))
        print("test " + '---Loss: {loss:.4f} | Dice: {dice:.4f}'.format(loss=test_loss.avg, dice=test_precision.avg))

        # summary
        writer.scalar_summary("Loss/loss", {"train": train_loss.avg, "test": test_loss.avg}, epoch)
        writer.scalar_summary("Loss/precision", {"train": train_precision.avg, "test": test_precision.avg}, epoch)

        writer.close()




if __name__ == "__main__":
    # config
    args = _base_options()
    device = _device()
    # data
    train_loader, test_loader = _load_data()
    # logger
    writer = _logger(log_dir=os.path.join(args.log_dir, args.model))
    # model
    model = SimpleNet(in_channels=1, out_channels=10).to(device)
    optimizor = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    train(model, train_loader, test_loader, criterion, optimizor, args.epochs, device, writer, save_weight=True)


"""    
    args = _base_options()
    _show_examples()  # ———>  样例图片显示
"""

训练过程

在这里插入图片描述

日志

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1084976.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

各种小程序/PC/移动端修改代码/文件后, 不热重载问题修复

各种框架有各种的配置,首先检查是否配置了热更新, 打个比方,taro中底版本中配置热更新是在配置文件中开启mini.hot,高版本中,终端执行打包命令的时候,添加–watch参数, 如果以上检查都没有问题,依旧还是没有热更新,那么很有可能是依赖出了问题,可以通过以下方法,重新…

Spring 趣玩

1、修改控制台启动显示的图案 在SpringBoot项目的resources目录下新建一个banner.txt文本文件,然后将启动Banner粘贴到此文本文件中,启动项目即可。 在线制作banner ①、http://patorjk.com/software/taag/ ②、https://www.bootschool.net/ascii ③、ht…

mysql数据库创建及用户添加和权限管理

1、创建数据库: CREATE DATABASE database_name; 例如: CREATE DATABASE mydatabase; 2、创建用户: CREATE USER usernamehostname; 例如: CREATE USER myuserlocalhost; 注意:替换 username 为你想要创建的用户名&a…

香港云服务器使用的小误区

​  当前,在海外市场的发展下,香港云服务器被推的火热。一方面,您可以根据需要积极利用它的免备案和国际线路等特性,另一方面,也可以借助它,使用尽可能多或尽可能少的存储空间,您的业务也可以…

高校教务系统登录页面JS分析——合肥工业大学

高校教务系统密码加密逻辑及JS逆向 本文将介绍高校教务系统的密码加密逻辑以及使用JavaScript进行逆向分析的过程。通过本文,你将了解到密码加密的基本概念、常用加密算法以及如何通过逆向分析来破解密码。 本文仅供交流学习,勿用于非法用途。 一、密码加…

干货分享|超全项目管理流程PPT汇总

我是胖圆,欢迎大家关注留言~ 或者移步公众号【胖圆说PM】找我~

linux放开8080端口

linux放开8080端口 输入命令: 查看8080端口是否开放 firewall-cmd --query-port8080/tcpno显示端口未打开,yes表示开启,linux开启防火墙默认关闭8080端口 这里是引用:https://blog.csdn.net/weixin_54067866/article/details/1…

虹科AR VIP研讨会 | 数字世界,视觉无界,诚邀您前来体验!

文章来源:虹科数字化AR 点击阅读原文:https://mp.weixin.qq.com/s/Q1YbpD0Mxq-KctOMALM0AA 点击链接报名:https://mp.weixin.qq.com/s/Q1YbpD0Mxq-KctOMALM0AA 主题速览 01 医疗保健的未来趋势:透过智能眼镜成像技术改善微创手术…

WorkPlus AI智能助理,基于GPT为企业提供专属的私有化部署解决方案

在当今数字时代,优质的客户服务是企业取得成功的重要因素之一。随着人工智能技术的不断发展,私有化部署AI智能客服成为企业提高客户体验、提升服务效率的新途径。WorkPlus作为领先的品牌,专注于提供可信赖的私有化部署解决方案,助…

物联网23届毕业了不想干Java转嵌入式可行吗?

物联网23届毕业了不想干Java转嵌入式可行吗? 可以的,现在嵌入式物联网的就业前景是比较不错的,物联网和嵌入式是相结合的,题主是电子信息工程专业的,小谷建议你学习嵌入式物联网,其覆盖的范围还是比较广的&#xff0c…

day04_方法_数组

今日内容 方法数组基础 复习 0 if(){ }else{ } if(){ }else if(){ }else if(){ }else{ } 1 循环的四要素 初始值,控制条件,迭代,循环体 2 迭代什么意思 迭代: 一次次的变化 3 for循环执行流程 (小红旗) for(int i 1;i < 11;i){ } 4 break,continue,return分别什么作用 brea…

正规的股票杠杆公司排名:安全的五大加杠杆的股票平台排行一览

在股票市场中&#xff0c;利用杠杆效应可以放大投资者的收益。然而&#xff0c;选择一个正规、可靠的杠杆公司至关重要。本篇文章将根据配查信、尚红网、倍悦网、兴盛网、诚利和、嘉正网、广瑞网、富灯网、天创网、恒正网、创通网等媒体提供的信息&#xff0c;分析并整理出正规…

我们用i.MX8M Plus开发了一个人工智能机器人小车,欢迎围观~

i.MX8M Plus的人工智能机器小车功能 AGV 小车是基于 i.MX8M Plus 为主控实现的一款双驱差速小车。从上到下由摄像头、舵机控制板、舵机、i.MX8M Plus 核心板与底板、电池、电机、轮子等组成。 i.MX8M Plus 有以下 5 个职责&#xff1a; 控制小车电机&#xff0c;负责控制小车…

智能洗衣管理系统中的RFID技术应用

一、行业背景 当前&#xff0c;酒店、医院、浴场以及专业洗涤公司面临着大量工作服和布草的交接、洗涤、熨烫、整理和储存等工序&#xff0c;如何有效地跟踪管理每一件布草的洗涤过程、洗涤次数、库存状态和布草归类等成为了一个巨大的挑战&#xff1a; 1、传统的纸面洗涤任务…

CentOS 7 基于C 连接ZooKeeper 客户端

前提条件&#xff1a;CentOS 7 编译ZooKeeper 客户端&#xff0c;请参考&#xff1a;CentOS 7 编译ZooKeeper 客户端 1、Docker 安装ZooKeeper # docker 获取zookeeper 最新版本 docker pull zookeeper# docker 容器包含镜像查看 docker iamges# 准备zookeeper 镜像文件挂载对…

相机噪声评估

当拥有一个相机&#xff0c;并且写了一个降噪的算法&#xff0c;想要测试降噪的应用效果。 相机在光线不足的情况下产生噪点的原因主要与以下几个因素有关&#xff1a; 感光元件的工作原理&#xff1a;相机的图像传感器是由数百万甚至数千万的感光元件&#xff08;如CMOS或CC…

论文研读|Robust Watermarking of Neural Network with Exponential Weighting

目录 论文信息文章简介研究动机查询修改攻击Auto-Encoder训练水印图像检测检测结果 水印图像重构 研究方法水印生成水印嵌入版权验证 实验结果保真度&#xff08;Fidelity&#xff09;有效性&#xff08;Effectiveness&#xff09;&鲁棒性&#xff08;Robustness&#xff0…

Apipost自动化测试

Apipost提供可视化的API自动化测试功能&#xff0c;使用Apipost研发人员可以设计、调试接口&#xff0c;测试人员可以基于同一数据源进行测试&#xff0c;Apipost 接口自动化功能在上次更新中进行了逻辑调整&#xff0c;带来更好的交互操作、更多的控制器选择&#xff0c;同时新…

【数据结构】深入探讨二叉树的遍历和分治思想(一)

&#x1f6a9;纸上得来终觉浅&#xff0c; 绝知此事要躬行。 &#x1f31f;主页&#xff1a;June-Frost &#x1f680;专栏&#xff1a;数据结构 &#x1f525;该文章主要讲述二叉树的递归结构及分治算法的思想。 目录&#xff1a; &#x1f30d;前言&#xff1a;&#x1f30d;…

使用react-router-dom在新标签页打开链接,而不是本页跳转

一般单页面应用&#xff0c;当你使用useNavigate时候的时候&#xff0c;用useNavigate来跳转&#xff0c;只能是在当前页面刷新跳转的&#xff0c;要想单独在一个tab页打开新页面&#xff0c;大概用三种方式。 第一种 使用link标签&#xff0c;配合target实现 <Link to&q…