PyTorch概述(二)---MNIST

news2024/10/5 18:23:33

NIST Special Database3

  • 具体指的是一个更大的特殊数据库3;
  • 该数据库的内容为手写数字黑白图片;
  • 该数据库由美国人口普查局的雇员手写

NIST Special Database1

  • 特殊数据库1;
  • 该数据库的内容为手写数字黑白图片;
  • 该数据库的图片由高中学生手写;

MNIST

  • MNIST 数据库:Modified National Institute of Standards and Technology 数据库
  • 是一个大的手写数字的集合;
  • 具有训练集60,000个;
  • 测试集10,000个;
  • 是NIST3和NIST1的子集;
  • 数字图片已经被居中,以固定的尺寸值标准化处理;
  • 原始的黑白两层图像被设置为20x20 像素大小,且保持宽高比;
  • 结果图像在标准化算法中的反走样技术的处理下包含灰度级图像;
  • 通过计算像素的质心,和平移操作,手写的数字被居中放置到尺寸为28X28的图片中;

MNIST 用法

transform=transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize([0,],[1,])])
trainset=torchvision.datasets.MNIST(root='./data',
                                        train=True,
                                        download=True,
                                        transform=transform)
trainloader=torch.utils.data.DataLoader(trainset,
                                        batch_size=32,
                                        shuffle=True,
                                        num_workers=2)
testset=torchvision.datasets.MNIST(root='./data',
                                       train=False,
                                       download=True,
                                        transform=transform)
testloader=torch.utils.data.DataLoader(testset,
                                        batch_size=32,
                                        shuffle=True,
                                        num_workers=2)

MNIST 源码(python)

import codecs
import os
import os.path
import shutil
import string
import sys
import warnings
from typing import Any,Callable,Dict,List,Optional,Tuple
from urllib.error import URLError

import numpy as np
import torch
from PIL import Image

from .utils import _flip_byte_order,check_integrity,download_and_extract_archive,extract_archive,verify_str_arg
from .vision import VisionDataset

class MNIST(VisionDataset):
    '''
    'MNIST <http://yann.lecun.com/exdb/mnist/>' _Dataset.
    '''
    mirrors=["http://yann.lecun.com/exdb/mnist/","https://ossci-datasets.s3.amazonaws.com/mnist/"]
    resource=[("train-images-idx3-ubyte.gz","f68b3c2dcbeaaa9fbdd348bbdeb94873"),
              ("train-labels-idx1-ubyte.gz","d53e105ee54ea40749a09fcbcd1e9432"),
              ("t10k-images-idx3-ubyte.gz","9fb629c4189551a2d022fa330f9573f3"),
              ("t10k-labels-idx1-ubyte.gz","ec29112dd5afa0611ce80d1b7f02629c")]
    training_file="training.pt"
    test_file="test.pt"
    classes=["0-zero",
             "1-one",
             "2-two",
             "3-three",
             "4-four",
             "5-five",
             "6-six",
             "7-seven",
             "8-eight",
             "9-nine"]
    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets")
        return self.targets
    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets")
        return self.targets
    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data")
        return self.data
    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data")
        return self.data
    def __init__(self,root:str,
                 train:bool=True,
                 transform:Optional[Callable]=None,
                 target_transform:Optional[Callable]=None,
                 download:bool=False)->None:
        '''
        Args
        :param root: string,root directory of dataset where 'MNIST/raw/train-images-idx3-ubyte' and 'MNIST/raw/t10k-images-idx3-ubyte' exist.
        :param train:(bool,optional),if true,creates dataset from 'train-images-idx3-utyte',otherwise from 't10k-images-idx3-utyte'.
        :param transform:(callable,optional),a function/transform that takes in an PIL image and returns a transformed version.E.g,'transform.RandomCrop'
        :param target_transform:(callable,optional),a function/transform that takes in the target and transform it.
        :param download:(bool,optional),if True,downloads the dataset from the internet and puts it in root directory.If dataset is already downloaded,it is not download again.
        '''
        super().__init__(root,transform,target_transform)
        self.train=train

        if self._check_legacy_exist():
            self.data,self.targets=self._load_legacy_data()
            return
        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError("Dataset not found.You can use download=True to download it")
        self.data,self.targets=self._load_data()

    def _check_legacy_exist(self):
        processed_folder_exists=os.path.exists(self.processed_folder)
        if not processed_folder_exists:
            return False
        return all(check_integrity(os.path.join(self.processed_folder,file)) for file in (self.training_file,self.test_file))
    def _load_legacy_data(self):
        #This is for BC only,We no longer cache the data in a custom binary,but simply read from the raw data directly.
        data_file=self.training_file if self.train else self.test_file
        return torch.load(os.path.join(self.processed_folder,data_file))
    def _load_data(self):
        image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
        data = read_image_file(os.path.join(self.raw_folder, image_file))

        label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
        targets = read_label_file(os.path.join(self.raw_folder, label_file))

        return data, targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode="L")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def raw_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, "raw")

    @property
    def processed_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, "processed")

    @property
    def class_to_idx(self) -> Dict[str, int]:
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self) -> bool:
        return all(
            check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
            for url, _ in self.resources
        )

    def download(self) -> None:
        """Download the MNIST data if it doesn't exist already."""

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        # download files
        for filename, md5 in self.resources:
            for mirror in self.mirrors:
                url = f"{mirror}{filename}"
                try:
                    print(f"Downloading {url}")
                    download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
                except URLError as error:
                    print(f"Failed to download (trying next):\n{error}")
                    continue
                finally:
                    print()
                break
            else:
                raise RuntimeError(f"Error downloading {filename}")

    def extra_repr(self) -> str:
        split = "Train" if self.train is True else "Test"
        return f"Split: {split}"

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1466954.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GitCode配置ssh

下载SSH windows设置里选“应用” 选“可选功能” 添加功能 安装这个 坐等安装&#xff0c;安装好后可以关闭设置。 运行 打开cmd 执行如下指令&#xff0c;启动SSH服务。 net start sshd设置开机自启动 把OpenSSH服务添加到Windows自启动服务中&#xff0c;可避免每…

mysql的日志文件在哪?

阅读本文之前请参阅----MySQL 数据库安装教程详解&#xff08;linux系统和windows系统&#xff09; MySQL的日志文件通常包括错误日志、查询日志、慢查询日志和二进制日志等。这些日志文件的位置取决于MySQL的安装和配置。以下是一些常见的日志文件位置和如何找到它们&#xff…

【kubernetes】二进制部署k8s集群之,多master节点负载均衡以及高可用(下)

↑↑↑↑接上一篇继续部署↑↑↑↑ 之前已经完成了单master节点的部署&#xff0c;现在需要完成多master节点以及实现k8s集群的高可用 一、完成master02节点的初始化操作 二、在master01节点基础上&#xff0c;完成master02节点部署 步骤一&#xff1a;准备好master节点所需…

调用 Python 函数遗漏括号 ( )

调用 Python 函数遗漏括号 1. Example - error2. Example - correctionReferences 1. Example - error name "Forever Strong" print(name.upper()) print(name.lower)FOREVER STRONG <built-in method lower of str object at 0x0000000002310670>---------…

【ArcGIS】利用高程进行坡度分析:区域面/河道坡度

在ArcGIS中利用高程进行坡度分析 坡度ArcGIS实操案例1&#xff1a;流域面上坡度计算案例2&#xff1a;河道坡度计算2.1 案例数据2.2 操作步骤 参考 坡度 坡度是地表单元陡缓的程度&#xff0c;通常把坡面的垂直高度和水平距离的比值称为坡度。 坡度的表示方法有百分比法、度数…

单片机04__基本定时器__毫秒微秒延时

基本定时器__毫秒微秒延时 基本定时器介绍&#xff08;STM32F40x&#xff09; STM32F40X芯片一共包含14个定时器&#xff0c;这14个定时器分为3大类&#xff1a; 通用定时器 10个 TIM9-TIM1和TIM2-TIM5 具有基本定时器功能&#xff0c; 还具有输入捕获&#xff0c;输出比较功…

yarn install:unable to get local issuer certificate

一、问题描述 今天在Jenkins上发布项目时&#xff0c;遇到一个报错&#xff1a; error Error: unable to get local issuer certificateat TLSSocket.onConnectSecure (node:_tls_wrap:1535:34)at TLSSocket.emit (node:events:513:28)at TLSSocket._finishInit (node:_tls_w…

PLC_博图系列☞基本指令“取反RLO”

PLC_博图系列☞基本指令“取反RLO” 文章目录 PLC_博图系列☞基本指令“取反RLO”背景介绍取反RLO说明示例 关键字&#xff1a; PLC、 西门子、 博图、 Siemens 、 取反RLO 背景介绍 这是一篇关于PLC编程的文章&#xff0c;特别是关于西门子的博图软件。我并不是专业的PLC…

谷歌Gemma开源了

1、Gemma的表现 自从大模型横空出世之后&#xff0c;大部分大模型都是闭源的&#xff0c;只有少部分模型选择开源。谷歌推出了全新的开源模型系列Gemma&#xff0c;相比谷歌之前的 Gemini模型&#xff0c;Gemma 更加轻量&#xff0c;可以免费使用&#xff0c;模型权重也一并开…

详解编译和链接!

目录 1. 翻译环境和运行环境 2. 翻译环境 2.1 预处理 2.2 编译 2.3 汇编 2.4 链接 3. 运行环境 4.完结散花 悟已往之不谏&#xff0c;知来者犹可追 创作不易&#xff0c;宝子们&#xff01;如果这篇文章对你们…

Vue+SpringBoot打造开放实验室管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、研究内容2.1 实验室类型模块2.2 实验室模块2.3 实验管理模块2.4 实验设备模块2.5 实验订单模块 三、系统设计3.1 用例设计3.2 数据库设计 四、系统展示五、样例代码5.1 查询实验室设备5.2 实验放号5.3 实验预定 六、免责说明 一、摘…

多窗口编程

六、多窗口编程 QMessageBox消息对话框&#xff08;掌握&#xff09; QMessageBox继承自QDialog&#xff0c;显示一个模态对话框。用于用户前台信息通知或询问用户问题&#xff0c;并接收问题答案。 QDialog的Qt源码中&#xff0c;派生类往往都是一些在特定场合下使用的预设好的…

【Vuforia+Unity】AR03-圆柱体物体识别(Cylinder Targets)

1.创建数据库模型 这个是让我们把生活中类似圆柱体和圆锥体的物体进行AR识别所选择的模型 Bottom Diameter:底部直径 Top Diameter:顶部直径 Side Length:圆柱侧面长度 请注意&#xff0c;您不必上传所有三个部分的图片&#xff0c;但您需要先为侧面曲面关联一个图像&#…

Threejs 实现3D影像地图,Json地图,地图下钻

1.使用threejs实现3D影像地图效果&#xff0c;整体效果看起来还可以&#xff0c;底层抽象了基类&#xff0c;实现了通用&#xff0c;对任意省份&#xff0c;城市都可以只替换数据&#xff0c;即可轻松实现效果。 效果如下&#xff1a; 链接https://www.bilibili.com/video/BV1…

[AutoSar]BSW_Com1 Can通信入门

目录 关键词平台说明一、车身CAN简介二、相关模块三、Can报文分类及信号流路径3.1 应用报文3.2 应用报文&#xff08;多路复用multiplexer&#xff09;3.3 诊断报文3.4 网络管理报文3.5 XCP报文&#xff08;标定报文&#xff09; 关键词 嵌入式、C语言、autosar、OS、BSW 平台…

分散的产品开发团队

分散的产品开发团队指的是各个团队或成员在地理位置上分布在不同地方&#xff0c;通过互联网和现代通讯技术进行协作和沟通&#xff0c;以共同完成产品开发任务的团队模式。 这种团队模式的优势在于可以充分利用各地的人才资源&#xff0c;降低团队的管理和协作成本&#xff0…

RobotGPT:利用ChatGPT的机器人操作学习框架,三星电子研究院与张建伟院士、孙富春教授、方斌教授合作发表RAL论文

1 引言 大型语言模型&#xff08;LLMs&#xff09;在文本生成、翻译和代码合成方面展示了令人印象深刻的能力。最近的工作集中在将LLMs&#xff0c;特别是ChatGPT&#xff0c;整合到机器人技术中&#xff0c;用于任务如零次系统规划。尽管取得了进展&#xff0c;LLMs在机器人技…

VSCODE使用Django 页面和渲染

https://code.visualstudio.com/docs/python/tutorial-django#_use-a-template-to-render-a-page 通过模板渲染页面 文件 实现步骤 1&#xff0c; 修改代码&#xff0c;hello的App名字增加到installed_apps表中。 2&#xff0c; hello子目录下&#xff0c;创建 .\templates\…

RT-Thread 时钟 timer delay 相关

前言 此处,介绍对delay 时钟 timer 这几部分之间的关联和相关的知识点;本来只是想介绍一下 delay的,但是发现说到delay 不先 提到 先验知识 晶振\时钟\时钟节拍\定时器 好像没法解释透彻,所以就变成了 晶振\时钟\时钟节拍\定时器\delay 的很简单的概括一遍;并附带上能直接运行的…

STM32物联网(封装AT指令进行TCP连接及数据的接收和发送)

文章目录 前言一、AT指令函数封装1.向ESP8266发送数据函数2.设置ESP8266工作模式3.连接WIFI函数4.查询IP地址5.连接TCP服务器6.发送数据到TCP服务器7.接收并解析来自TCP服务器的数据8.关闭TCP服务器 二、代码测试总结 前言 本篇文章将继续带大家学习STM32物联网&#xff0c;那…