Pytorch深度学习实践笔记8(b站刘二大人)

news2024/9/22 21:28:22

🎬个人简介:一个全栈工程师的升级之路!
📋个人专栏:pytorch深度学习
🎀CSDN主页 发狂的小花
🌄人生秘诀:学习的本质就是极致重复!

《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili​

目录

1 Pytorch 数据加载

2 Dataset和DataLoader

3 程序


1 Pytorch 数据加载

  • epoch、Batch-size 、iteration


例如下图:
8个样本、shuffle是打乱样本的顺序,Batch-szie为2,iteration 就是 8 / 2 为4,epoch是训练集进行几个轮次的迭代。

 




2 Dataset和DataLoader

 




Dataset 是一个抽象类,使用时必须进行重写,from 在torch.utils.data Dataset
(1)重写时,需要根据数据来进行构造__init__(self,filepath)
(2)__getitem__(self,index)用来让数据可以进行索引操作
(3)__len__(self)用来获取数据集的大小
DataLoader 用来加载数据为mini-Batch ,支持Batch-size 的设置,shuffle支持数据的打乱顺序。

  • 参数说明:
from torch.utils.data import DataLoader
 
test_load = DataLoader(dataset=test_data, batch_size=4 , shuffle= True, num_workers=0,drop_last=False)


batch_size=4表示每次取四个数据
shuffle= True表示开启数据集随机重排,即每次取完数据之后,打乱剩余数据的顺序,然后再进行下一次取
num_workers=0表示在主进程中加载数据而不使用任何额外的子进程,如果大于0,表示开启多个进程,进程越多,处理数据的速度越快,但是会使电脑性能下降,占用更多的内存
drop_last=False表示不丢弃最后一个批次,假设我数据集有10个数据,我的batch_size=3,即每次取三个数据,那么我最后一次只有一个数据能取,如果设置为true,则不丢弃这个包含1个数据的子集数据,反之则丢弃

 

  • 数据转换为dataset形式,进行DataLoader的使用
x_data = torch.tensor([[1.0],[2.0],[3.0],[4.0],[5.0],[6.0],[7.0],[8.0],[9.0]])
y_data = torch.tensor([[2.0],[4.0],[6.0],[8.0],[10.0],[12.0],[14.0],[16.0],[18.0]])

dataset = Data.TensorDataset(x_data,y_data)

loader = Data.DataLoader(  
    dataset=dataset,  
    batch_size=BATCH_SIZE,  
    shuffle=True,  
    num_workers=0  
)

pytorch中的DataLoader_pytorch dataloader-CSDN博客​


3 程序


数据分为训练集和测试集:Adam 训练

import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
 
 
# 读取原始数据,并划分训练集和测试集
raw_data = np.loadtxt('./dataset/diabetes.csv.gz', delimiter=',', dtype=np.float32)
X = raw_data[:, :-1]
Y = raw_data[:, [-1]]
Xtrain, Xtest, Ytrain, Ytest = train_test_split(X,Y,test_size=0.1)
Xtest = torch.from_numpy(Xtest)
Ytest = torch.from_numpy(Ytest)
 
# 将训练数据集进行批量处理
# prepare dataset
 
class DiabetesDataset(Dataset):
    def __init__(self, data,label):
 
        self.len = data.shape[0] # shape(多少行,多少列)
        self.x_data = torch.from_numpy(data)
        self.y_data = torch.from_numpy(label)
 
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]
 
    def __len__(self):
        return self.len
 
 
train_dataset = DiabetesDataset(Xtrain,Ytrain)
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True, num_workers=0) #num_workers 多线程
 
# design model using class
 
 
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 2)
        self.linear4 = torch.nn.Linear(2, 1)
        self.sigmoid = torch.nn.Sigmoid()
 
    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        x = self.sigmoid(self.linear4(x))
        return x
 
 
model = Model()
 
# construct loss and optimizer
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
epoch_list = []
loss_list = []

# training cycle forward, backward, update
def train(epoch):
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        y_pred = model(inputs)
 
        loss = criterion(y_pred, labels)
 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

 
 
def test():
    with torch.no_grad():
        y_pred = model(Xtest)
        y_pred_label = torch.where(y_pred>=0.5,torch.tensor([1.0]),torch.tensor([0.0]))
        acc = torch.eq(y_pred_label, Ytest).sum().item() / Ytest.size(0)
        print("test acc:", acc)
 
if __name__ == '__main__':
    for epoch in range(10000):
        loss_val = train(epoch)
        print("epoch: ",epoch," loss: ",loss_val)
        epoch_list.append(epoch)
        loss_list.append(loss_val)
    test()

    plt.plot(epoch_list,loss_list)
    plt.title("Adam")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.savefig("./data/pytorch7_1.png")



简单的程序
 

import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
 
# prepare dataset
 
 
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0] # shape(多少行,多少列)
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])
 
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]
 
    def __len__(self):
        return self.len
 
 
dataset = DiabetesDataset('./dataset/diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0) #num_workers 多线程
 
 
# design model using class
 
 
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()
 
    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x
 
 
model = Model()
 
# construct loss and optimizer
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
 
# training cycle forward, backward, update
if __name__ == '__main__':
    for epoch in range(100):
        for i, data in enumerate(train_loader, 0): # train_loader 是先shuffle后mini_batch
            inputs, labels = data
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())
 
            optimizer.zero_grad()
            loss.backward()
 
            optimizer.step()

🌈我的分享也就到此结束啦🌈
如果我的分享也能对你有帮助,那就太好了!
若有不足,还请大家多多指正,我们一起学习交流!
📢未来的富豪们:点赞👍→收藏⭐→关注🔍,如果能评论下就太惊喜了!
感谢大家的观看和支持!最后,☺祝愿大家每天有钱赚!!!欢迎关注、关注!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1702041.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

新窃密软件 NodeStealer 可以窃取所有浏览器 Cookie

Netskope 的研究人员正在跟踪一个使用恶意 Python 脚本窃取 Facebook 用户凭据与浏览器数据的攻击行动。攻击针对 Facebook 企业账户,包含虚假 Facebook 消息并带有恶意文件。攻击的受害者主要集中在南欧与北美,以制造业和技术服务行业为主。 2023 年 1 …

二维前缀和[模版]

题目链接 题目: 分析: 求二维数组的区间和问题, 可以使用二维前缀和算法注意: 下标从1开始计数第一步: 预处理出来一个前缀和矩阵dp[i][j] 表示: 从[1,1] 位置到[i,j] 位置, 这段区间里面所有元素的和 dp[i][j] 就等于ABCD, A好求, 就是dp[i-1][j-1], 但BC不好求, 所以我们AB…

D - Permutation Subsequence(AtCoder Beginner Contest 352)

题目链接: D - Permutation Subsequence (atcoder.jp) 题目大意: 分析: 相对于是记录一下每个数的位置 然后再长度为k的区间进行移动 然后看最大的pos和最小的pos的最小值是多少 有点类似于滑动窗口 用到了java里面的 TreeSet和Map TreeSet存的是数…

删除edge浏览器文本框储存记录值以及关闭自动填充

当我们点击输入框时总会出现许多以前输入过的信息。 一、删除edge浏览器文本框储存记录值 1、首先按下↓键选中你想删除的信息 二、关闭自动填充。 1、在地址栏输入edge://wallet/settings跳转到以下界面 2、往下滑找到 全部取消即可

区块链技术引领:Web3时代的新网络革命

随着区块链技术的快速发展和不断成熟,人们已经开始意识到它所带来的潜在影响,尤其是在构建一个更加去中心化、安全和透明的互联网时。这个新的互联网时代被称为Web3,它将不再受制于传统的中心化平台,而是更多地依赖于去中心化的网…

水电站泄水预警广播系统解决方案

一、背景 随着水电站不断发展,泄水预警成为确保周边地区安全的重要环节。为了有效应对泄水事件,提高预警效率,本方案提出建设水电站泄水预警广播系统。该系统通过先进的技术手段,实现快速、准确的泄水预警信息传达,保…

产品推荐-光学镜片镀膜自动上下料设备

随着现代化工业生产的浪潮,智能化和自动化已成为工业发展的必然趋势。在精密制造领域,高精度和高效率更是工艺流程中不可或缺的要素。为满足这一需求,富唯推出了引领行业潮流的智能设备——富唯智能镀膜上下料设备。 一、多功能操作&#xff…

spring suite gitlab使用手册

一、gitlab介绍 GitLab是一个功能丰富的开源代码管理平台,基于Git进行版本控制,并提供了一系列用于团队协作、项目管理、持续集成/持续部署(CI/CD)等工具。以下是关于GitLab的详细介绍: 基础信息: GitLab…

Text Control 控件 中 Service Pack 3:MailMerge 支持 SVG 图像

图像的合并方式与报告模板中的合并字段相同。占位符在设计时添加,并与文件、数据库或内存中的数据合并。可以将图像对象添加到具有指定名称的模板中。数据列必须包含字节数组形式的二进制图像数据、System.Drawing.Image 类型的对象、文件名、十六进制或 Base64 编码…

Android Compose 七:常用组件 Image

1 基本使用 Image(painter painterResource(id R.drawable.ic_wang_lufei), contentDescription "" ) // 图片Spacer(modifier Modifier.height(20.dp))Image(imageVector ImageVector.vectorResource(id R.drawable.ic_android_black_24dp), contentDescript…

Python中的SSH、SFTP和FTP操作详解

大家好,在网络编程中,安全地连接到远程服务器并执行操作是一项常见任务。Python 提供了多种库来实现这一目标,其中 Paramiko 是一个功能强大的工具,可以轻松地在 Python 中执行 SSH、SFTP 和 FTP 操作。本文将介绍如何使用 Parami…

动态规划-二维费用的背包问题

文章目录 1. 一和零(474)2. 盈利计划(879) 1. 一和零(474) 题目描述: 状态表示: 我们之前的01背包问题以及完全背包问题都是一维的,因为我们只有一个要求或者说是限制那…

IT技术 | 电脑蓝屏修复记录DRIVER_IRQL_NOT_LESS_OR_EQUAL

我的台式机是iMac 2015年的,硬盘是机械的,时间久了运行越来越慢。后来对苹果系统失去了兴趣,想换回windows,且想换固态硬盘,就使用winToGo 搞了双系统,在USB外接移动固态硬盘上安装了win10系统。 最近&…

安卓六种页面加载优化方案对比总结

根据工作经验,笔者提炼了六种页面加载优化方式,按照业务与非业务,将六种加载方式分为两类: 业务类 控制业务与UI的执行顺序、控制多业务之间的执行顺序 ①预加载:是指在进入页面之前,提前获得页面所需得数据…

2024经济管理、社会科学与教育国际会议(ICEMSSE 2024)

2024经济管理、社会科学与教育国际会议(ICEMSSE 2024) 会议简介 2024年国际经济管理、社会科学和教育会议(ICEMSSE 2024)专注于经济、社会发展和教育。会议旨在为专家、学者和社会人士提供一个交流平台。通过讨论科学研究成果和前沿技术,我…

浅谈配置元件之CSV 数据文件设置

浅谈配置元件之CSV 数据文件设置 为了增强测试的真实性和多样性,JMeter 提供了多种数据参数化的方式,其中 CSV 数据文件设置(CSV Data Set Config)是一种常用且强大的功能,它允许测试脚本从外部CSV文件中动态读取数据…

基于springboot+vue的“漫画之家”系统

开发语言:Java框架:springbootJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:…

Samtec mPower®电源连接器:高能、可靠、灵活、小巧

【摘要/前言】 电源连接器是互连解决方案中不可或缺的一个组成部分。虽然相较于比较爱“竞速”的信号连接器,电源连接器的技术迭代不是那么频繁,但是它是连接电源和用电设备的重要纽带,想要确保设备正常运行,就少不了它的身影。 …

EPSON RX8111CE+松下高性能电池的组合应用

RTC是一种实时时钟,用于记录和跟踪时间,具有独立供电和时钟功能。在某些应用场景中,为了保证RTC在断电或者其他异常情况下依然能够正常工作,需要备份电池方案来提供稳定的供电。本文将介绍EPSON爱普生nA级RTC RX8111CE松下Panason…