机器学习复习(4)——CNN算法

news2024/9/25 9:31:12

目录

数据增强方法

CNN图像分类数据集构建

导入数据集

定义trainer

超参数设置

数据增强

构建CNN网络

开始训练

模型测试

数据增强方法

# 一般情况下,我们不会在验证集和测试集上做数据扩增
# 我们只需要将图片裁剪成同样的大小并装换成Tensor就行
test_tfm = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# 当然,我们也可以再测试集中对数据进行扩增(对同样本的不同装换)
#  - 用训练数据的装化方法(train_tfm)去对测试集数据进行转化,产出扩增样本
#  - 对同个照片的不同样本分别进行预测
#  - 最后可以用soft vote / hard vote 等集成方法输出最后的预测
train_tfm = transforms.Compose([
    # 图片裁剪 (height = width = 128)
    transforms.Resize((128, 128)),
    transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
    # ToTensor() 放在所有处理的最后
    transforms.ToTensor(),
])

CNN图像分类数据集构建

class FoodDataset(Dataset):
    # 构造函数
    def __init__(self, path, tfm=test_tfm, files=None):
        # 调用父类的构造函数
        super(FoodDataset).__init__()
        # 存储图像文件夹路径
        self.path = path
        # 从路径中获取所有以.jpg结尾的文件,并按字典顺序排序
        self.files = sorted([os.path.join(path, x) for x in os.listdir(path) if x.endswith(".jpg")])
        # 如果提供了文件列表,则使用该列表代替自动搜索得到的列表
        if files is not None:
            self.files = files
        # 打印路径中的一个样本文件路径
        print(f"One {path} sample", self.files[0])
        # 存储用于图像变换的函数
        self.transform = tfm
    # 返回数据集中的样本数
    def __len__(self):
        return len(self.files)
    # 根据索引获取单个样本
    def __getitem__(self, idx):
        # 获取文件名
        fname = self.files[idx]
        # 打开图像文件
        im = Image.open(fname)
        # 应用变换
        im = self.transform(im)
        # 尝试从文件名中提取标签,如果失败则设置为-1(表示测试集中没有标签)
        try:
            label = int(fname.split("/")[-1].split("_")[0])
        except:
            label = -1  # 测试集没有label
        # 返回图像和标签
        return im, label

导入数据集

注意这里的“私有方法”

_dataset_dir = config['dataset_dir']#“_”是为了避免和python中的dataset重名

train_set = FoodDataset(os.path.join(_dataset_dir,"training"), tfm=train_tfm)
train_loader = DataLoader(train_set, batch_size=config['batch_size'], 
                          shuffle=True, num_workers=0, pin_memory=True)

valid_set = FoodDataset(os.path.join(_dataset_dir,"validation"), tfm=test_tfm)
valid_loader = DataLoader(valid_set, batch_size=config['batch_size'], 
                          shuffle=True, num_workers=0, pin_memory=True)

# 测试级保证输出顺序一致
test_set = FoodDataset(os.path.join(_dataset_dir,"test"), tfm=test_tfm)
test_loader = DataLoader(test_set, batch_size=config['batch_size'], 
                         shuffle=False, num_workers=0, pin_memory=True)

定义trainer

def trainer(train_loader, valid_loader, model, config, device, rest_net_flag=False):
    # 定义交叉熵损失函数,用于评估分类任务的模型性能
    criterion = nn.CrossEntropyLoss()

    # 初始化优化器,这里使用Adam优化器
    optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

    # 根据rest_net_flag标志选择模型保存路径
    save_path = config['save_path'] if rest_net_flag else config['resnet_save_path']

    # 初始化TensorBoard的SummaryWriter,用于记录训练过程
    writer = SummaryWriter()

    # 如果'models'目录不存在,则创建该目录
    if not os.path.isdir('./models'):
        os.mkdir('./models')

    # 初始化训练参数:训练轮数、最佳损失、步骤计数器和早停计数器
    n_epochs, best_loss, step, early_stop_count = config['n_epochs'], math.inf, 0, 0

    # 进行多个训练周期
    for epoch in range(n_epochs):
        # 设置模型为训练模式
        model.train()

        # 初始化损失记录器和准确率记录器
        loss_record = []
        train_accs = []

        # 使用tqdm显示训练进度条
        train_pbar = tqdm(train_loader, position=0, leave=True)

        # 遍历训练数据
        for x, y in train_pbar:
            # 重置优化器梯度
            optimizer.zero_grad()

            # 将数据和标签移动到指定设备(如GPU)
            x, y = x.to(device), y.to(device)

            # 进行一次前向传播
            pred = model(x)

            # 计算损失
            loss = criterion(pred, y)

            # 反向传播
            loss.backward()

            # 如果启用梯度裁剪,则应用梯度裁剪
            if config['clip_flag']:
                grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)

            # 进行一步优化(梯度下降)
            optimizer.step()

            # 记录当前步骤
            step += 1

            # 计算准确率并记录损失和准确率
            acc = (pred.argmax(dim=-1) == y.to(device)).float().mean()
            l_ = loss.detach().item()
            loss_record.append(l_)
            train_accs.append(acc.detach().item())
            train_pbar.set_description(f'Epoch [{epoch+1}/{n_epochs}]')
            train_pbar.set_postfix({'loss': f'{l_:.5f}', 'acc': f'{acc:.5f}'})

        # 计算并记录平均训练损失和准确率
        mean_train_acc = sum(train_accs) / len(train_accs)
        mean_train_loss = sum(loss_record) / len(loss_record)
        writer.add_scalar('Loss/train', mean_train_loss, step)
        writer.add_scalar('ACC/train', mean_train_acc, step)

        # 设置模型为评估模式
        model.eval()

        # 初始化验证集损失记录器和准确率记录器
        loss_record = []
        test_accs = []

        # 遍历验证数据
        for x, y in valid_loader:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                pred = model(x)
                loss = criterion(pred, y)
                acc = (pred.argmax(dim=-1) == y.to(device)).float().mean()

            loss_record.append(loss.item())
            test_accs.append(acc.detach().item())

        # 计算并打印平均验证损失和准确率
        mean_valid_acc = sum(test_accs) / len(test_accs)
        mean_valid_loss = sum(loss_record) / len(loss_record)
        print(f'Epoch [{epoch+1}/{n_epochs}]: Train loss: {mean_train_loss:.4f}, acc: {mean_train_acc:.4f} Valid loss: {mean_valid_loss:.4f}, acc: {mean

超参数设置

device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = {
    'seed': 6666,
    'dataset_dir': "../input/data",
    'n_epochs': 10,      
    'batch_size': 64, 
    'learning_rate': 0.0003,           
    'weight_decay':1e-5,
    'early_stop': 300,
    'clip_flag': True, 
    'save_path': './models/model.ckpt',
    'resnet_save_path': './models/resnet_model.ckpt'
}
print(device)
all_seed(config['seed'])

数据增强

test_set = FoodDataset(os.path.join(_dataset_dir,"test"), tfm=train_tfm)
test_loader_extra1 = DataLoader(test_set, batch_size=config['batch_size'], 
                                shuffle=False, num_workers=0, pin_memory=True)

test_set = FoodDataset(os.path.join(_dataset_dir,"test"), tfm=train_tfm)
test_loader_extra2 = DataLoader(test_set, batch_size=config['batch_size'], 
                                shuffle=False, num_workers=0, pin_memory=True)

test_set = FoodDataset(os.path.join(_dataset_dir,"test"), tfm=train_tfm)
test_loader_extra3 = DataLoader(test_set, batch_size=config['batch_size'], 
                                shuffle=False, num_workers=0, pin_memory=True)

构建CNN网络

class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        # input 維度 [3, 128, 128]
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),  # [64, 128, 128]
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),      # [64, 64, 64]

            nn.Conv2d(64, 128, 3, 1, 1), # [128, 64, 64]
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),      # [128, 32, 32]

            nn.Conv2d(128, 256, 3, 1, 1), # [256, 32, 32]
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),      # [256, 16, 16]

            nn.Conv2d(256, 512, 3, 1, 1), # [512, 16, 16]
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),       # [512, 8, 8]
            
            nn.Conv2d(512, 512, 3, 1, 1), # [512, 8, 8]
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),       # [512, 4, 4]
        )
        self.fc = nn.Sequential(
            nn.Linear(512*4*4, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 11)
        )

    def forward(self, x):
        out = self.cnn(x)
        out = out.view(out.size()[0], -1)
        return self.fc(out)

举一个具体的例子来解释: out = out.view(out.size()[0], -1)

假设我们有一个4维的张量 out,其维度是 [10, 3, 32, 32]。这个张量可以被理解为一个小批量(batch)的图像数据,其中:

  • 10 是批处理大小(batch size),表示有10个图像。
  • 3 是通道数(channels),例如在RGB图像中有3个颜色通道。
  • 3232 是图像的高度和宽度。

现在,我们想将这个4维张量转换为2维张量,以便它可以被用作全连接层(dense layer)的输入。这就是 out.view(out.size()[0], -1) 用途所在。

执行这个操作后,张量的形状将会是:

  • 第一个维度仍然是10,这保持了批处理大小不变。
  • 第二个维度是由-1指定的,这让PyTorch自动计算这个维度的大小。在我们的例子中,其余的维度(3, 32, 32)将被展平,所以第二个维度的大小是 3 * 32 * 32 = 3072。

因此,执行 out = out.view(out.size()[0], -1) 后,out 的形状将会从 [10, 3, 32, 32] 变为 [10, 3072]。这个新的二维张量可以被看作是一个包含10个样本的数据批次,每个样本都被展平为3072个特征的一维数组。这种形状的张量适合作为全连接层的输入。

1. Conv2d(卷积层)

卷积层的输出尺寸可以用以下公式计算:

其中:

  • 输入尺寸是输入特征图的高度或宽度。
  • 卷积核尺寸是卷积核的高度或宽度。
  • 填充(Padding)是在输入特征图周围添加的零的层数。
  • 步长(Stride)是卷积核移动的步幅。

2. MaxPool2d(最大池化层)

最大池化层的输出尺寸可以用类似的公式计算:

对于最大池化,通常不使用填充

 假设我们有一个大小为[32, 32](高度32,宽度32)的输入特征图,并且我们想应用以下两个层:

  1. Conv2d层,卷积核大小为[3, 3],步长为1,填充为1
  2. MaxPool2d层,池化核大小为[2, 2],步长为2

对于Conv2d层,输出尺寸计算如下:

对于MaxPool2d层,输出尺寸计算如下:

所以,经过这两层处理后,最终输出的特征图尺寸将会是[16, 16]

开始训练

model = Classifier().to(device)
trainer(train_loader, valid_loader, model, config, device)

或者可以通过调用pytorch官方的一些标准model进行训

from torchvision.models import resnet50
resNet = resnet50(pretrained=False)
# 残差网络
resNet = resNet.to(device)
trainer(train_loader, valid_loader, resNet, config, device)

模型测试

model_best = Classifier().to(device)
model_best.load_state_dict(torch.load(config['save_path']))
model_best.eval()
prediction = []
with torch.no_grad():
    for data,_ in test_loader:
        test_pred = model_best(data.to(device))
        test_label = np.argmax(test_pred.cpu().data.numpy(), axis=1)
        prediction += test_label.squeeze().tolist()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1430133.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

跨平台开发:浅析uni-app及其他主流APP开发方式

随着智能手机的普及,移动应用程序(APP)的需求不断增长。开发一款优秀的APP,不仅需要考虑功能和用户体验,还需要选择一种适合的开发方式。随着技术的发展,目前有多种主流的APP开发方式可供选择,其…

【计网·湖科大·思科】实验七 路由信息协议RIP、开放最短路径优先协议OSPF、边界网关协议BGP

🕺作者: 主页 我的专栏C语言从0到1探秘C数据结构从0到1探秘Linux 😘欢迎关注:👍点赞🙌收藏✍️留言 🏇码字不易,你的👍点赞🙌收藏❤️关注对我真的很重要&…

nginx初学者指南

一、启动、停止和重新加载配置 前提:先要启动nginx 在Windows上启动nginx的步骤如下: 1. 下载并安装nginx。可以从nginx官网下载适合自己操作系统的版本,一般是zip压缩包,解压到指定目录中。 2. 进入nginx的安装目录&#xff…

简单几步,借助Aapose.Cells将 Excel 工作表拆分为文件

近年来,Excel 文件已成为无数企业数据管理的支柱。然而,管理大型 Excel 文件可能是一项艰巨的任务,尤其是在高效共享和处理数据时。为了应对这一挑战,大型 Excel 工作簿被拆分为较小的工作簿以增强电子表格管理。Aspose提供了这样…

electron项目在内网环境的linux环境下进行打包

Linux需要的文件: electron-v13.0.0-linux-x64.zip appimage-12.0.1.7z snap-template-electron-4.0-1-amd64.tar.7z 下载慢或者下载失败的情况可以手动下载以上electron文件复制到指定文件夹下: 1.electron-v13.0.0-linux-x64.zip 复制到~/.cache/electron/目录下…

Blender使用Rigify和Game Rig Tool基础

做动画需要的几个简要步骤: 1.建模 2.绑定骨骼 3.绘制权重 4.动画 有一个免费的插件可以处理好给引擎用:Game Rig Tool 3.6和4.0版本的 百度网盘 提取码:vju8 1.Rigify是干嘛用的? 》 绑定骨骼 2.Game Rig Tool干嘛用的&#xf…

LVGL部件8

一.按钮矩阵部件 1.知识概览 2.函数接口 1.lv_btnmatrix_set_btn_ctrl 在 LVGL(LittlevGL)中,lv_btnmatrix_set_btn_ctrl() 函数用于设置按钮矩阵(Button Matrix)中单个按钮的控制选项。该函数可以用来定制按钮矩阵中…

寒假作业2月3号

第二章 引用内联重载 一.选择题 1、适宜采用inline定义函数情况是(C) A. 函数体含有循环语句 B. 函数体含有递归语句 C. 函数代码少、频繁调用 D. 函数代码多、不常调用 2、假定一个函数为A(int i4, int j0) {;}, 则执行“A (1);”语句…

【蓝桥杯】环形链表的约瑟夫问题

目录 题目描述: 输入描述: 输出描述: 示例1 解法一(C): 解法二(Cpp): 正文开始: 题目描述: 据说著名犹太历史学家 Josephus 有过以下故事&a…

UE4 C++ 枚举类型

先在UCLASS()前写入: //定义枚举变量:方法一 UENUM(BlueprintType) //BlueprintType:在蓝图中可显示、创建该枚举变量 namespace MyEnumType //namespace:命名空间,支持同样的变量命令、便于访问//MyEnumType&#xf…

如何保证MySQL和Redis中的数据一致性?

文章目录 前言一、缓存案例1.1 缓存常见用法1.2 缓存不一致产生的原因 二、解决方案2.1 先删除缓存,再更新数据库2.2 先更新数据库,删除缓存2.3 只更新缓存,由缓存自己同步更新数据库2.4 只更新缓存,由缓存自己异步更新数据库2.5 …

【MybatisPlus篇】查询条件设置(范围匹配 | 模糊匹配 | 空判定 | 包含性判定 | 分组 | 排序)

文章目录 🎄环境准备⭐导入依赖⭐写入User类⭐配置启动类⭐创建UserDao 的 MyBatis Mapper 接口,用于定义数据库访问操作⭐创建配置文件🛸创建测试类MpATest.java 🍔范围查询⭐eq⭐between⭐gt 🍔模糊匹配⭐like &…

力扣之2629.复合函数(reduceRight )

/*** param {Function[]} functions* return {Function}*/ var compose function(functions) {return function(x) {return functions.reduceRight((result, func) > func(result), x);} };/*** const fn compose([x > x 1, x > 2 * x])* fn(4) // 9*/ 说明&#x…

docker 容器指定主机网段

docker 容器指定主机网段。 直接连接到物理网络:使用macvlan技术可以让Docker容器直接连接到物理网络,而不需要通过NAT或端口映射的方式来访问它们。可以提高网络性能和稳定性,同时也可以使容器更加透明和易于管理。 1、查询网卡的名称&…

微软Office Plus与WPS Office的较量:办公软件市场将迎来巨变?

微软Office Plus在功能表现上远超WPS Office? 微软出品的Office套件实力强劲,其不仅在办公场景中扮演着不可或缺的角色,为用户带来高效便捷的体验,而且在娱乐生活管理等多元领域中同样展现出了卓越的应用价值 作为中国本土办公软…

c语言--求第n个斐波那契数列(递归、迭代)

目录 一、概念二、用迭代求第n个斐波那契数1.分析2.完整代码3.运行结果4.如果求第50个斐波那契数呢?看看会怎么样。4.1运行结果:4.2画图解释 三、用迭代的方式求第n个斐波那契数列1.分析2.完整代码3.运行结果4.求第50个斐波那契数4.1运行结果4.2运行结果…

基于粒子群算法的多无人机任务分配

python3.6以上正常运行 基于粒子群算法的多无人机任务分配资源-CSDN文库

在Flutter中调用Android的代码

参考 【Flutter 混合开发】嵌入原生View-Android 默认使用Android studio 和 Kotlin 基本配置 创建flutter项目 在终端执行 flutter create batterylevel添加 Android 平台的实现 打开项目下的android/app/src/main/kotlin 下的 MainActivity.kt 文件。 我这里编辑器有…

开源的三维算法库有哪些

PCL,VTK,VCG,CGAL,Open CASCADE(opencascade),OpenSceneGraph (OSG),Easy3D 点云网格处理算法:openmesh, meshlab三维算法库,Eigen 网格简化,网格平滑,网格参数化 无序…

北朝隋唐文物展亮相广西,文物预防性保护网关保驾护航

一、霸府名都——太原博物馆收藏北朝隋朝文物展 2月1日,广西民族博物馆与太原博物馆携手,盛大开启“霸府名都——太原博物馆北朝隋文物展”。此次新春展览精选了北朝隋唐时期150多件晋阳文物珍品。依据“巍巍雄镇”“惊世古冢”“锦绣名都”三个单元&am…