3-5 提高模型效果:归一化

news2025/1/9 19:43:20

3-5 提高模型效果:归一化

主目录点这里
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
举例

1. 批量归一化 (Batch Normalization, BN)

应用场景: 通常用于图像分类任务,它在训练期间对每个批次的数据进行归一化,以加速收敛并稳定训练过程。

代码示例:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的卷积神经网络使用批量归一化
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(64 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.bn1(self.conv1(x))
        x = nn.functional.relu(x)
        x = self.bn2(self.conv2(x))
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(-1, 64 * 6 * 6)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 初始化模型
model = SimpleCNN()

# 打印模型结构
print(model)

在这里插入图片描述

2. 层归一化 (Layer Normalization, LN)

应用场景: 通常用于文本分类任务或RNN/LSTM网络中,它对每个样本的每一层的所有神经元进行归一化,而不是每个批次的数据。

代码示例:

import torch
import torch.nn as nn

# 定义一个简单的全连接网络使用层归一化
class SimpleFC(nn.Module):
    def __init__(self):
        super(SimpleFC, self).__init__()
        self.fc1 = nn.Linear(20, 50)
        self.ln1 = nn.LayerNorm(50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = self.ln1(self.fc1(x))
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x

# 初始化模型
model = SimpleFC()

# 打印模型结构
print(model)

在这里插入图片描述

3. 实例归一化 (Instance Normalization, IN)

应用场景: 通常用于生成对抗网络(GAN)和风格迁移任务,它对每个样本的每个通道独立进行归一化,适合处理不同样本之间差异较大的情况。

代码示例:

import torch
import torch.nn as nn

# 定义一个简单的卷积神经网络使用实例归一化
class SimpleINN(nn.Module):
    def __init__(self):
        super(SimpleINN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.in1 = nn.InstanceNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.in2 = nn.InstanceNorm2d(64)
        self.fc1 = nn.Linear(64 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.in1(self.conv1(x))
        x = nn.functional.relu(x)
        x = self.in2(self.conv2(x))
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(-1, 64 * 6 * 6)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 初始化模型
model = SimpleINN()

# 打印模型结构
print(model)

在这里插入图片描述

4. 组归一化 (Group Normalization, GN)

应用场景: 通常用于小批量数据或训练时批量大小非常小的情况。它将通道划分为组,并对每组内的通道进行归一化。

代码示例:

import torch
import torch.nn as nn

# 定义一个简单的卷积神经网络使用组归一化
class SimpleGNN(nn.Module):
    def __init__(self):
        super(SimpleGNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.gn1 = nn.GroupNorm(4, 32)  # 32个通道分成4组,每组8个通道
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.gn2 = nn.GroupNorm(8, 64)  # 64个通道分成8组,每组8个通道
        self.fc1 = nn.Linear(64 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.gn1(self.conv1(x))
        x = nn.functional.relu(x)
        x = self.gn2(self.conv2(x))
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(-1, 64 * 6 * 6)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 初始化模型
model = SimpleGNN()

# 打印模型结构
print(model)

在这里插入图片描述

后话

这些归一化技术各有特点和应用场景,选择适合的归一化方法可以显著提升模型的训练效果和性能。为了更加体现归一化在模型训练中的效果,我们可以选择一个更复杂的任务和网络结构,比如在 CIFAR-10 数据集上训练一个较深的卷积神经网络(CNN),并比较是否使用归一化的模型的训练和验证性能。

CIFAR-10 是一个包含 10 类的彩色图像数据集。我们将训练两个深度卷积神经网络,一个使用批量归一化(Batch Normalization, BN),另一个不使用。我们将展示它们在训练速度、收敛性和最终精度上的差异。

# -*- coding: UTF-8 -*-
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split


# 定义不使用批量归一化的深度卷积神经网络模型
class DeepCNN_NoBN(nn.Module):
    def __init__(self):
        super(DeepCNN_NoBN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(256 * 4* 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x, 2)
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = nn.functional.relu(self.conv3(x))
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(-1, 256 * 4 * 4)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# 定义使用批量归一化的深度卷积神经网络模型
class DeepCNN_BN(nn.Module):
    def __init__(self):
        super(DeepCNN_BN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.fc1 = nn.Linear(256 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = nn.functional.relu(self.bn1(self.conv1(x)))
        x = nn.functional.max_pool2d(x, 2)
        x = nn.functional.relu(self.bn2(self.conv2(x)))
        x = nn.functional.max_pool2d(x, 2)
        x = nn.functional.relu(self.bn3(self.conv3(x)))
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(-1, 256 * 4 * 4)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# 训练和验证函数
def train_and_validate(model, train_loader, val_loader, epochs=20):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(epochs):
        model.train()
        train_loss, correct = 0, 0

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()

        train_loss /= len(train_loader)
        train_accuracy = 100. * correct / len(train_loader.dataset)

        model.eval()
        val_loss, correct = 0, 0

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                correct += predicted.eq(targets).sum().item()

        val_loss /= len(val_loader)
        val_accuracy = 100. * correct / len(val_loader.dataset)

        print(f"Epoch: {epoch + 1}, Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%")


def main():
    # CIFAR-10 数据集预处理
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    # 分割训练集为训练集和验证集
    train_size = int(0.8 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

    # 数据加载器,设置 num_workers 为 0 避免多进程问题
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=0)

    # 初始化模型
    model_no_bn = DeepCNN_NoBN()
    model_bn = DeepCNN_BN()

    print("Training model without Batch Normalization")
    train_and_validate(model_no_bn, train_loader, val_loader)

    print("\nTraining model with Batch Normalization")
    train_and_validate(model_bn, train_loader, val_loader)


if __name__ == '__main__':
    main()



在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1905418.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【实践分享】深度学习远程连接GPU

目录 前言 一、创建实例 二、上传文件 三、服务器上传 四、运行代码文件 前言 1、使用平台:恒源云 2、教程总结自B站大佬Larry同学发布的教程视频 一、创建实例 通俗:租用一台临时的电脑,电脑可自选GPU型号等,按照项目需…

Linux基础:一. 简单的命令

文章目录 一. 简单的命令1.1 关机1.2 重启1.3 控制台打印工作目录1.4 切换当前目录1.5 列出当前目录中的目录和文件1.6 列出指定目录中的目录和文件1.7 控制台清屏1.8 查看和设置时间1.8.1 查看时间1.8.2 设置时间,需要管理员权限 一. 简单的命令 1.1 关机 comman…

FairJob:促进在线广告系统公平性研究

在人工智能(AI)与人类动态的交汇处,既存在机遇也存在挑战,特别是在人工智能领域。尽管取得了进步,但根植于历史不平等中的持续偏见仍然渗透在我们的数据驱动系统中,这些偏见不仅延续了不公平现象&#xff0…

PingCAP 成为全球数据库管理系统市场增速最快的厂商

近日,Gartner 发布的《Market Share Analysis: Database Management Systems, Worldwide, 2023》(2024 年 6 月)报告显示:“2023 年全球数据库管理系统(DBMS)市场的增长率为 13.4%,略低于去年的…

排序 -- 计数排序以及对排序的总结

到了这篇文章就说明常见的排序我们就快要讲完了,那这篇文章我们就讲一下非比较排序--计数排序。 一、非比较排序 1.基本思想 计数排序又称为鸽巢原理,是对哈希直接定址法的变形应用。 操作步骤: 统计相同元素出现次数 根据统计的结果将序列…

LaTeX教程(014)-LaTeX文档结构(14)

LaTeX教程(014)- LaTeX \LaTeX LATE​X文档结构(14) 2.3.3 multitoc - 将目录设置为多栏 multitoc包的使用方法相当简单,只需要调用这个包,并将要设置为多栏(默认是双栏)的目录指定到包选项中即可。如\usepackage[toc]{multitoc},设置的就是…

25_嵌入式系统总线接口

目录 串行接口基本原理 串行通信 串行数据传送模式 串行通信方式 RS-232串行接口 RS-422串行接口 RS-485串行接口 RS串行总线总结 RapidIO高速串行总线 ARINC429总线 并行接口基本原理 并行通信 IEEE488总线 SCSI总线 MXI总线 PCI接口基本原理 PCI总线原理 PC…

jmeter-beanshell学习4-beanshell截取字符串

再写个简单点的东西,截取字符串,参数化文件统一用csv,然后还要用excel打开,如果是数字很容易格式就乱了。有同事是用双引号把数字引起来,报文里就不用加引号了,但是这样beanshell处理起来,好像容…

MATLAB中的SDPT3、LMILab、SeDuMi工具箱

MATLAB中的SDPT3、LMILab、SeDuMi工具箱都是用于解决特定数学优化问题的工具箱,它们在控制系统设计、机器学习、信号处理等领域有广泛的应用。以下是对这三个工具箱的详细介绍: 1. SDPT3工具箱 简介: SDPT3(Semidefinite Progra…

Nacos服务注册总流程(源码分析)

文章目录 服务注册NacosClient找看源码入口NacosClient服务注册源码NacosServer处理服务注册 服务注册 服务注册 在线流程图 NacosClient找看源码入口 我们启动一个微服务&#xff0c;引入nacos客户端的依赖 <dependency><groupId>com.alibaba.cloud</groupI…

Science Robotics 麻省理工学院最新研究,从仿真中学习的精确选择、定位和抓放物体的视触觉方法

现有的机器人系统在通用性和精确性两个性能目标上难以同时兼顾&#xff0c;往往会陷入一个机器人解决单个任务的情况&#xff0c;缺乏"精确泛化"。本文针对精准和通用的同时兼顾提出了解决方法。提出了SimPLE(Pick Localize和placE的仿真模拟)作为精确拾取和放置的解…

C# 如何获取属性的displayName的3种方式

文章目录 1. 使用特性直接访问2. 使用GetCustomAttribute()方法通过反射获取3. 使用LINQ查询总结和比较 在C#中&#xff0c;获取属性的displayName可以通过多种方式实现&#xff0c;包括使用特性、反射和LINQ。下面我将分别展示每种方法&#xff0c;并提供具体的示例代码。 1.…

【Spring Cloud】一个例程快速了解网关Gateway的使用

Spring Cloud Gateway提供了一个在Spring生态系统之上构建的API网关&#xff0c;包括&#xff1a;Spring 5&#xff0c;Spring Boot 2和Project Reactor。Spring Cloud Gateway旨在提供一种简单而有效的路由方式&#xff0c;并为它们提供一些网关基本功能&#xff0c;例如&…

轻松驾驭开发之旅:Maven配置阿里云CodeUp远程私有仓库全攻略

文章目录 引言一、为什么选择阿里云CodeUp作为远程私有仓库&#xff1f;二、Maven配置阿里云CodeUp远程私有仓库的步骤准备工作配置Maven的settings.xml文件配置项目的pom.xml文件验证配置是否成功 三、使用阿里云CodeUp远程私有仓库的注意事项 引言 在软件开发的世界里&#…

【Linux进程】命令行参数 环境变量(详解)

目录 前言 1. 命令行参数 什么是命令行参数? 2. 环境变量 常见的环境变量 如何修改环境变量? 获取环境变量 环境变量的组织方式 拓展问题 导入环境变量 3. 本地变量* 总结 前言 在使用Linux指令的时候, 都是指令后边根命令行参数, 每个指令本质都是一个一个的可执行程…

数学系C++ 排序算法简述(八)

目录 排序 选择排序 O(n2) 不稳定&#xff1a;48429 归并排序 O(n log n) 稳定 插入排序 O(n2) 堆排序 O(n log n) 希尔排序 O(n log2 n) 图书馆排序 O(n log n) 冒泡排序 O(n2) 优化&#xff1a; 基数排序 O(n k) 快速排序 O(n log n)【分治】 不稳定 桶排序 O(n…

Kaggle网站免费算力使用,深度学习模型训练

声明&#xff1a; 本文主要内容为&#xff1a;kaggle网站数据集上传&#xff0c;训练模型下载、模型部署、提交后台运行等教程。 1、账号注册 此步骤本文略过&#xff0c;如有需要可以参考其他文章。 2、上传资源 不论是上传训练好的模型进行预测&#xff0c;还是训练用的…

2024组装一台能跑AI大模型的电脑

title: 2024组装一台能跑AI大模型的电脑 tags: [组装电脑, AI大模型] categories: [其他, 电脑, windows] 这里不写组装步骤&#xff0c;哪里接线&#xff0c;购买什么品牌网上一大堆。 这里只写如何根据你自己的需求&#xff0c;选择合适的、兼容的配件。 概述 需求&#xff…

区间最值问题-RQM(ST表,线段树)

1.ST表求解 ST表的实质其实是动态规划&#xff0c;下面是区间最小的递归公式&#xff0c;最大只需将min改成max即可 f[i][j] min(f[i][j - 1], f[i (1 << j - 1)][j - 1]); 二维数组的f[i][j]表示从i开始连续2*j个数的最小/大值。 例如&#xff1a;我们给出一个数组…

iOS中多个tableView 嵌套滚动特性探索

嵌套滚动的机制 目前的结构是这样的&#xff0c;整个页面是一个大的tableView, Cell 是整个页面的大小&#xff0c;cell 中嵌套了一个tableView 通过测试我们发现滚动的时候&#xff0c;系统的机制是这样的&#xff0c; 我们滑动内部小的tableView, 开始滑动的时候&#xff0c…