Kaggle系列之CIFAR-10图像识别分类(残差网络模型ResNet-18)

news2025/1/22 12:55:58

CIFAR-10数据集在计算机视觉领域是一个很重要的数据集,很有必要去熟悉它,我们来到Kaggle站点,进入到比赛页面:https://www.kaggle.com/competitions/cifar-10

CIFAR-10是8000万小图像数据集的一个子集,由60000张32x32彩色图像组成,包含10个分类,每个类有6000张图像。

官方数据中有5万张训练图片和1万张测试图片。我们保留了原始数据集中的训练/测试分割

在Kaggle比赛提交的时候,为了阻止某些形式的作弊(比如手标),我们在测试集中添加了29万张垃圾图片。这些图像在评分时被忽略。我们还对官方的10000个测试图像进行了微小的修改,以防止通过文件散列查找它们。这些修改不应明显影响得分。您应该预测所有30万张图像的标签。

对于刷排行榜这些我们不用去管,秉持着学习为主的想法,我们来训练这个数据集。分成10个类别,分别为:airplane、automobile、bird、cat、deer、dog、frog、horse、ship、truck

这些类是完全相互排斥的,比如汽车和卡车之间没有重叠。automobile:包括轿车、suv之类的东西。truck:只包括大卡车。这两项都不包括皮卡车。

将下载的数据集放入到dataset目录,解压之后,在这个目录下面有train目录、test目录、trainLabels.csv标签文件,其中train里面是5万张图片、test里面是30万张图片

1、整理原始数据集

1.1读取训练集的标签文件

def read_label_file(data_dir,label_file,train_dir,valid_ratio):
    '''
    读取训练集的标签文件
        参数
            valid_ratio:验证集样本数与原始训练集样本数之比
        返回值
            n_train // len(labels):每个类多少张图片
            idx_label:50000个id:label的字典
    '''
    with open(os.path.join(data_dir,label_file),'r') as f:
        lines=f.readlines()[1:]
        tokens=[l.rstrip().split(',') for l in lines]
        idx_label=dict(((int(idx),label) for idx,label in tokens))
    #{'cat', 'ship', 'frog', 'dog', 'truck', 'deer', 'horse', 'bird', 'airplane', 'automobile'}
    labels=set(idx_label.values())#去重就是10个类别
    n_train_valid=len(os.listdir(os.path.join(data_dir,train_dir)))#50000
    n_train=int(n_train_valid*(1-valid_ratio))
    assert 0<n_train<n_train_valid
    return n_train // len(labels),idx_label

我们测试下,先熟悉下这个方法:

data_dir,label_file="dataset","trainLabels.csv"
train_dir,valid_ratio="train",0.1

n_train_per_label,idx_label=read_label_file(data_dir,label_file,train_dir,valid_ratio)
print(n_train_per_label,idx_label)#4500,{id:label,...}

读取标签文件,返回每个类有多少个训练样本(id:label这样的id对应标签的字典)

1.2切分验证数据集

上面读取标签的方法中有参数"valid_ratio",用来从原始训练集中切分出验证集,这里设定为0.1

接下来我们将切分的45000张图片用于训练,5000张图片用于验证,将它们分别存放到input_dir/train,input_dir/valid目录下面,这里的input_dir,我这里设置为train_valid_test,在train目录下面是10个分类的目录(这个将定义一个方法自动创建),每个分类目录里面是4500张所属类别的图片;在valid目录下面也是10个分类的目录(同样自动创建),每个分类目录里面是500张所属类别的图片;还有一个train_valid目录,下面同样是10个分类目录,每个类别目录包含5000张图片。

本人的路径如下:

D:\CIFAR10\ dataset\train_valid_test\train\[airplane...]\[1-4500].png
D:\CIFAR10\ dataset\train_valid_test\valid\[automobile...]\[1-500].png
D:\CIFAR10\ dataset\train_valid_test\train_valid\[bird...]\[1-5000].png

这里定义一个辅助函数,新建不存在的路径,将递归新建目录:

#辅助函数,路径不存在就创建
def mkdir_if_not_exist(path):
    if not os.path.exists(os.path.join(*path)):
        os.makedirs(os.path.join(*path))


def reorg_train_valid(data_dir,train_dir,input_dir,n_train_per_label,idx_label):
    '''
    切分训练数据集,分别生成train、valid、train_valid文件夹
    在这些目录下面分别生成10个类别目录,遍历图片拷贝到对应的类别目录
    '''
    label_count={}#{'frog': 4500, 'cat': 4500, 'automobile': 4500,...}
    for train_file in os.listdir(os.path.join(data_dir,train_dir)):
        idx=int(train_file.split('.')[0])
        label=idx_label[idx]#类别
        mkdir_if_not_exist([data_dir,input_dir,'train_valid',label])
        src1=os.path.join(data_dir,train_dir,train_file)
        dst1=os.path.join(data_dir,input_dir,'train_valid',label)
        shutil.copy(src1,dst1)#将图片拷贝到train_valid_test\train_valid\类别\
        if label not in label_count or label_count[label]<n_train_per_label:
            mkdir_if_not_exist([data_dir,input_dir,'train',label])
            src2=os.path.join(data_dir,train_dir,train_file)
            dst2=os.path.join(data_dir,input_dir,'train',label)
            shutil.copy(src2,dst2)
            label_count[label]=label_count.get(label,0)+1#每个类别数量累加,小于n_train_per_label=4500
        else:
            mkdir_if_not_exist([data_dir,input_dir,'valid',label])
            src3=os.path.join(data_dir,train_dir,train_file)
            dst3=os.path.join(data_dir,input_dir,'valid',label)
            shutil.copy(src3,dst3)

input_dir='train_valid_test'
reorg_train_valid(data_dir,train_dir,input_dir,n_train_per_label,idx_label)

这个图片数量比较多,拷贝过程比较耗时,所以我们可以使用进度条来显示我们拷贝的进展。

from tqdm import tqdm
    with tqdm(total=len(os.listdir(os.path.join(data_dir,train_dir)))) as pbar:
        for train_file in tqdm(os.listdir(os.path.join(data_dir,train_dir))):
            ......

更多关于进度条的知识,可以参阅:Python中tqdm进度条的详细介绍(安装程序与耗时的迭代)最终结果是训练数据集的图片都拷贝到了各自所对应类别的目录里面。

1.3整理测试数据集

训练与验证的数据集做好,接下来做一个测试集用来预测的时候使用。

def reorg_test(data_dir,test_dir,input_dir):
    mkdir_if_not_exist([data_dir,input_dir,'test','unknown'])
    for test_file in os.listdir(os.path.join(data_dir,test_dir)):
        src=os.path.join(data_dir,test_dir,test_file)
        dst=os.path.join(data_dir,input_dir,'test','unknown')
        shutil.copy(src,dst)

reorg_test(data_dir,'test',input_dir)

这样就将dataset\test中的测试图片拷贝到了dataset\train_valid_test\test\unknown目录下面,当然简单起见直接手动拷贝过去也可以。

2、读取整理后的数据集

2.1、图像增广

为了应对过拟合,我们使用图像增广,关于图像增广在前面章节有讲过,有兴趣的也可以查阅:

计算机视觉之图像增广(翻转、随机裁剪、颜色变化[亮度、对比度、饱和度、色调])

这里我们将训练数据集做一些随机翻转、缩放裁剪与通道的标准化等处理,对测试与验证数据集只做个标准化处理

# 训练集图像增广
transform_train = gdata.vision.transforms.Compose([gdata.vision.transforms.Resize(40),
                                                   gdata.vision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0), ratio=(1.0, 1.0)),
                                                   gdata.vision.transforms.RandomFlipLeftRight(),
                                                   gdata.vision.transforms.ToTensor(),
                                                   gdata.vision.transforms.Normalize([0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010])
                                                   ])

#测试集图像增广
transform_test = gdata.vision.transforms.Compose([gdata.vision.transforms.ToTensor(),
                                                  gdata.vision.transforms.Normalize([0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010])])

2.2、读取数据集

读取增广后的数据集,使用ImageFolderDataset实例来读取整理之后的文件夹里的图片数据集,其中每个数据样本包括图像和标签。

#ImageFolderDataset加载存储在文件夹结构中的图像文件的数据集
train_ds = gdata.vision.ImageFolderDataset(os.path.join(data_dir, input_dir, 'train'), flag=1)
valid_ds = gdata.vision.ImageFolderDataset(os.path.join(data_dir, input_dir, 'valid'), flag=1)
train_valid_ds = gdata.vision.ImageFolderDataset(os.path.join(data_dir, input_dir, 'train_valid'), flag=1)
test_ds = gdata.vision.ImageFolderDataset(os.path.join(data_dir, input_dir, 'test'), flag=1)
print(train_ds.items[0:2],train_ds.items[-2:])
'''
[('dataset\\train_valid_test\\train\\airplane\\10009.png', 0), ('dataset\\train_valid_test\\train\\airplane\\10011.png', 0)]
[('dataset\\train_valid_test\\train\\truck\\5235.png', 9), ('dataset\\train_valid_test\\train\\truck\\5236.png', 9)]
'''

打印的items来看,返回的是列表,里面的元素是元组对,分别是图片路径与标签(类别)值。

然后我们使用DataLoader实例,指定增广之后的数据集,返回小批量数据。在训练时,我们仅用验证集评价模型,因此需要保证输出的确定性。在预测时,我们将在训练集和验证集的并集上训练模型,以充分利用所有标注的数据。

#DataLoader从数据集中加载数据并返回小批量数据
batch_size = 128
train_iter = gdata.DataLoader(train_ds.transform_first(transform_train), batch_size, shuffle=True, last_batch='keep')
valid_iter = gdata.DataLoader(valid_ds.transform_first(transform_test), batch_size, shuffle=True, last_batch='keep')
train_valid_iter = gdata.DataLoader(train_valid_ds.transform_first(transform_train), batch_size, shuffle=True, last_batch='keep')
test_iter = gdata.DataLoader(test_ds.transform_first(transform_test), batch_size, shuffle=False, last_batch='keep')

3、定义模型

数据集处理好了之后,我们就可以开始定义合适的模型了,我们选用残差网络ResNet-18模型,在此之前我们先使用基于HybridBlock类构建残差块:

#定义残差块
class Residual(nn.HybridBlock):
    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
        super(Residual, self).__init__(**kwargs)
        self.conv1 = nn.Conv2D(num_channels, kernel_size=3,padding=1, strides=strides)
        self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2D(
                num_channels, kernel_size=1, strides=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm()
        self.bn2 = nn.BatchNorm()

    def hybrid_forward(self, F, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        return F.relu(Y+X)

定义好了残差块,就可以方便的构建残差网络了。

#ResNet-18模型
def resnet18(num_classes):
    net = nn.HybridSequential()
    net.add(nn.Conv2D(64, kernel_size=3, strides=1, padding=1),nn.BatchNorm(), nn.Activation('relu'))

    def resnet_block(num_channels, num_residuals, first_block=False):
        blk = nn.HybridSequential()
        for i in range(num_residuals):
            if i == 0 and not first_block:
                blk.add(Residual(num_channels, use_1x1conv=True, strides=2))
            else:
                blk.add(Residual(num_channels))
        return blk

    net.add(resnet_block(64, 2, first_block=True), resnet_block(128, 2), resnet_block(256, 2), resnet_block(512, 2))
    net.add(nn.GlobalAvgPool2D(), nn.Dense(num_classes))
    return net

定义好了模型,在训练之前我们使用Xavier随机初始化,我们这里是CIFAR10数据集,有10个分类,所以最终的稠密层我们输出的是10:

def get_net(ctx):
    num_classes = 10
    net = resnet18(num_classes)
    net.initialize(ctx=ctx, init=init.Xavier())
    return net
loss=gloss.SoftmaxCrossEntropyLoss()

4、训练模型

模型初始化好了之后,就可以对其进行训练了,定义一个训练函数train:

def train(net, train_iter, valid_iter, num_epochs, lr, wd, ctx, lr_period, lr_decay):
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr, 'momentum': 0.9, 'wd': wd})
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
        if epoch > 0 and epoch % lr_period == 0:
            trainer.set_learning_rate(trainer.learning_rate*lr_decay)
        for X, y in train_iter:
            y = y.astype('float32').as_in_context(ctx)
            with autograd.record():
                y_hat = net(X.as_in_context(ctx))
                l = loss(y_hat, y).sum()
            l.backward()
            trainer.step(batch_size)
            train_l_sum += l.asscalar()
            train_acc_sum += (y_hat.argmax(axis=1) == y).sum().asscalar()
            n += y.size
        time_s = "time %.2f sec" % (time.time()-start)
        if valid_iter is not None:
            # 评估给定数据集上模型的准确性《使用验证集》
            valid_acc = d2l.evaluate_accuracy(valid_iter, net, ctx)
            epoch_s = ("epoch %d,loss %f,train acc %f,valid acc %f," %
                       (epoch+1, train_l_sum/n, train_acc_sum/n, valid_acc))
        else:
            epoch_s = ("epoch %d,loss %f,train acc %f," %
                       (epoch+1, train_l_sum/n, train_acc_sum/n))
        print(epoch_s+time_s+',lr '+str(trainer.learning_rate))

定义好了train函数,就可以进行训练了

# 开始训练
ctx, num_epochs, lr, wd = d2l.try_gpu(), 1, 0.1, 5e-4
lr_period, lr_decay, net = 80, 0.1, get_net(ctx)
net.hybridize()
train(net, train_iter, valid_iter, num_epochs, lr, wd, ctx, lr_period, lr_decay)

这里我们可以简单的将num_epochs设置为1,只迭代一次看下程序有没有什么bug与运行的怎么样:

epoch 1,loss 2.033364,train acc 0.294133,valid acc 0.345600,time 288.89 sec,lr 0.1

运行是没有什么问题,接下来就正式进入到分类的主题了

5、测试集分类

模型训练没有什么问题,超参数什么的也设置好了,我们使用所有训练数据集(包括验证集)重新训练模型,对测试集进行分类,这里我使用5个迭代来看下效果会是怎么样的:

num_epochs, preds = 5, []
net.hybridize()
train(net, train_valid_iter, None, num_epochs,lr, wd, ctx, lr_period, lr_decay)
for X, _ in test_iter:
    y_hat = net(X.as_in_context(ctx))
    preds.extend(y_hat.argmax(axis=1).astype(int).asnumpy())
sorted_ids = list(range(1, len(test_ds)+1))
sorted_ids.sort(key=lambda x: str(x))
df = pd.DataFrame({'id': sorted_ids, 'label': preds})
df['label'] = df['label'].apply(lambda x: train_valid_ds.synsets[x])
df.to_csv('submission.csv', index=False)

'''
epoch 1,loss 2.192931,train acc 0.253960,time 346.49 sec,lr 0.1
epoch 2,loss 1.663164,train acc 0.390080,time 118.79 sec,lr 0.1
epoch 3,loss 1.493299,train acc 0.456140,time 118.91 sec,lr 0.1
epoch 4,loss 1.356744,train acc 0.509440,time 117.40 sec,lr 0.1
epoch 5,loss 1.235666,train acc 0.556580,time 114.41 sec,lr 0.1
'''

可以看到损失在降低,精度在增加,一切正常,训练完毕将生成一个提交文件:submission.csv

然后将这个submission.csv文件提交看下打分与排名,当然这里可以将迭代次数调大,准确度也是会上来的,我迭代了100次然后提交看下分数如何,结果还是不错的

附上全部代码:

import pandas as pd
import d2lzh as d2l
import os
from mxnet import autograd,gluon,init
from mxnet.gluon import data as gdata,loss as gloss,nn
import shutil
import time

def read_label_file(data_dir,label_file,train_dir,valid_ratio):
    '''
    读取训练集的标签文件
        参数
            valid_ratio:验证集样本数与原始训练集样本数之比
        返回值
            n_train // len(labels):每个类多少张图片
            idx_label:50000个id:label的字典
    '''
    with open(os.path.join(data_dir,label_file),'r') as f:
        lines=f.readlines()[1:]
        tokens=[l.rstrip().split(',') for l in lines]
        idx_label=dict(((int(idx),label) for idx,label in tokens))
    #{'cat', 'ship', 'frog', 'dog', 'truck', 'deer', 'horse', 'bird', 'airplane', 'automobile'}
    labels=set(idx_label.values())#去重就是10个类别
    n_train_valid=len(os.listdir(os.path.join(data_dir,train_dir)))#50000
    n_train=int(n_train_valid*(1-valid_ratio))
    assert 0<n_train<n_train_valid
    return n_train // len(labels),idx_label

data_dir,label_file="dataset","trainLabels.csv"
train_dir,valid_ratio="train",0.1

n_train_per_label,idx_label=read_label_file(data_dir,label_file,train_dir,valid_ratio)
#print(n_train_per_label,len(idx_label))

#辅助函数,路径不存在就创建
def mkdir_if_not_exist(path):
    if not os.path.exists(os.path.join(*path)):
        os.makedirs(os.path.join(*path))

def reorg_train_valid(data_dir,train_dir,input_dir,n_train_per_label,idx_label):
    '''
    切分训练数据集,分别生成train、valid、train_valid文件夹
    在这些目录下面分别生成10个类别目录,遍历图片拷贝到对应的类别目录
    '''
    label_count={}#{'frog': 4500, 'cat': 4500, 'automobile': 4500,...}
    from tqdm import tqdm
    with tqdm(total=len(os.listdir(os.path.join(data_dir,train_dir)))) as pbar:
        for train_file in tqdm(os.listdir(os.path.join(data_dir,train_dir))):
            idx=int(train_file.split('.')[0])
            label=idx_label[idx]#类别
            mkdir_if_not_exist([data_dir,input_dir,'train_valid',label])
            src1=os.path.join(data_dir,train_dir,train_file)
            dst1=os.path.join(data_dir,input_dir,'train_valid',label)
            #shutil.copy(src1,dst1)#将图片拷贝到train_valid_test\train_valid\类别\
        
            if label not in label_count or label_count[label]<n_train_per_label:
                mkdir_if_not_exist([data_dir,input_dir,'train',label])
                src2=os.path.join(data_dir,train_dir,train_file)
                dst2=os.path.join(data_dir,input_dir,'train',label)
                #shutil.copy(src2,dst2)
                label_count[label]=label_count.get(label,0)+1#每个类别数量累加,小于n_train_per_label=4500
            else:
                mkdir_if_not_exist([data_dir,input_dir,'valid',label])
                src3=os.path.join(data_dir,train_dir,train_file)
                dst3=os.path.join(data_dir,input_dir,'valid',label)
                #shutil.copy(src3,dst3)

input_dir='train_valid_test'
#reorg_train_valid(data_dir,train_dir,input_dir,n_train_per_label,idx_label)

def reorg_test(data_dir,test_dir,input_dir):
    mkdir_if_not_exist([data_dir,input_dir,'test','unknown'])
    for test_file in os.listdir(os.path.join(data_dir,test_dir)):
        src=os.path.join(data_dir,test_dir,test_file)
        dst=os.path.join(data_dir,input_dir,'test','unknown')
        #shutil.copy(src,dst)

#reorg_test(data_dir,'test',input_dir)


# 训练集图像增广
transform_train = gdata.vision.transforms.Compose([gdata.vision.transforms.Resize(40),
                                                   gdata.vision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0), ratio=(1.0, 1.0)),
                                                   gdata.vision.transforms.RandomFlipLeftRight(),
                                                   gdata.vision.transforms.ToTensor(),
                                                   gdata.vision.transforms.Normalize([0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010])
                                                   ])

#测试集图像增广
transform_test = gdata.vision.transforms.Compose([gdata.vision.transforms.ToTensor(),
                                                  gdata.vision.transforms.Normalize([0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010])])

#读取增广后的数据集
#ImageFolderDataset加载存储在文件夹结构中的图像文件的数据集
train_ds = gdata.vision.ImageFolderDataset(os.path.join(data_dir, input_dir, 'train'), flag=1)
valid_ds = gdata.vision.ImageFolderDataset(os.path.join(data_dir, input_dir, 'valid'), flag=1)
train_valid_ds = gdata.vision.ImageFolderDataset(os.path.join(data_dir, input_dir, 'train_valid'), flag=1)
test_ds = gdata.vision.ImageFolderDataset(os.path.join(data_dir, input_dir, 'test'), flag=1)
#print(train_ds.items[0:2],train_ds.items[-2:])
'''
[('dataset\\train_valid_test\\train\\airplane\\10009.png', 0), ('dataset\\train_valid_test\\train\\airplane\\10011.png', 0)]
[('dataset\\train_valid_test\\train\\truck\\5235.png', 9), ('dataset\\train_valid_test\\train\\truck\\5236.png', 9)]
'''

#DataLoader从数据集中加载数据并返回小批量数据
batch_size = 128
train_iter = gdata.DataLoader(train_ds.transform_first(transform_train), batch_size, shuffle=True, last_batch='keep')
valid_iter = gdata.DataLoader(valid_ds.transform_first(transform_test), batch_size, shuffle=True, last_batch='keep')
train_valid_iter = gdata.DataLoader(train_valid_ds.transform_first(transform_train), batch_size, shuffle=True, last_batch='keep')
test_iter = gdata.DataLoader(test_ds.transform_first(transform_test), batch_size, shuffle=False, last_batch='keep')


#-----------------定义模型--------------------
#定义残差块
class Residual(nn.HybridBlock):
    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
        super(Residual, self).__init__(**kwargs)
        self.conv1 = nn.Conv2D(num_channels, kernel_size=3, padding=1, strides=strides)
        self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2D(num_channels, kernel_size=1, strides=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm()
        self.bn2 = nn.BatchNorm()

    def hybrid_forward(self, F, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        return F.relu(Y+X)

#ResNet-18模型
def resnet18(num_classes):
    net = nn.HybridSequential()
    net.add(nn.Conv2D(64, kernel_size=3, strides=1, padding=1),nn.BatchNorm(), nn.Activation('relu'))

    def resnet_block(num_channels, num_residuals, first_block=False):
        blk = nn.HybridSequential()
        for i in range(num_residuals):
            if i == 0 and not first_block:
                blk.add(Residual(num_channels, use_1x1conv=True, strides=2))
            else:
                blk.add(Residual(num_channels))
        return blk

    net.add(resnet_block(64, 2, first_block=True), resnet_block(128, 2), resnet_block(256, 2), resnet_block(512, 2))
    net.add(nn.GlobalAvgPool2D(), nn.Dense(num_classes))
    return net


def get_net(ctx):
    num_classes = 10
    net = resnet18(num_classes)
    net.initialize(ctx=ctx, init=init.Xavier())
    return net
loss=gloss.SoftmaxCrossEntropyLoss()

#---------------------训练函数---------------------
def train(net, train_iter, valid_iter, num_epochs, lr, wd, ctx, lr_period, lr_decay):
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr, 'momentum': 0.9, 'wd': wd})
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
        if epoch > 0 and epoch % lr_period == 0:
            trainer.set_learning_rate(trainer.learning_rate*lr_decay)
        for X, y in train_iter:
            y = y.astype('float32').as_in_context(ctx)
            with autograd.record():
                y_hat = net(X.as_in_context(ctx))
                l = loss(y_hat, y).sum()
            l.backward()
            trainer.step(batch_size)
            train_l_sum += l.asscalar()
            train_acc_sum += (y_hat.argmax(axis=1) == y).sum().asscalar()
            n += y.size
        time_s = "time %.2f sec" % (time.time()-start)
        if valid_iter is not None:
            # 评估给定数据集上模型的准确性《使用验证集》
            valid_acc = d2l.evaluate_accuracy(valid_iter, net, ctx)
            epoch_s = ("epoch %d,loss %f,train acc %f,valid acc %f," %
                       (epoch+1, train_l_sum/n, train_acc_sum/n, valid_acc))
        else:
            epoch_s = ("epoch %d,loss %f,train acc %f," %
                       (epoch+1, train_l_sum/n, train_acc_sum/n))
        print(epoch_s+time_s+',lr '+str(trainer.learning_rate))


# 开始训练
ctx, num_epochs, lr, wd = d2l.try_gpu(), 1, 0.1, 5e-4
lr_period, lr_decay, net = 80, 0.1, get_net(ctx)
#net.hybridize()
#train(net, train_iter, valid_iter, num_epochs, lr, wd, ctx, lr_period, lr_decay)

num_epochs, preds = 100, []
net.hybridize()
train(net, train_valid_iter, None, num_epochs,lr, wd, ctx, lr_period, lr_decay)
for X, _ in test_iter:
    y_hat = net(X.as_in_context(ctx))
    preds.extend(y_hat.argmax(axis=1).astype(int).asnumpy())
sorted_ids = list(range(1, len(test_ds)+1))
sorted_ids.sort(key=lambda x: str(x))
df = pd.DataFrame({'id': sorted_ids, 'label': preds})
#apply应用synsets方法,将0~9的数字分别转换为airplane、automobile...对应的类别
#synsets方法大家可以看定义,就是获取文件夹名称(类别)
df['label'] = df['label'].apply(lambda x: train_valid_ds.synsets[x])
df.to_csv('submission.csv', index=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/358417.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

spring cloud gateway集成sentinel并扩展支持restful api进行url粒度的流量治理

sentinel集成网关支持restful接口进行url粒度的流量治理前言使用网关进行总体流量治理&#xff08;sentinel版本&#xff1a;1.8.6&#xff09;1、cloud gateway添加依赖:2、sentinel配置3、网关类型项目配置4、通过zk事件监听刷新上报api分组信息1、非网关项目上报api分组信息…

I.MX6ULL_Linux_系统篇(16) uboot分析-启动流程

原文链接&#xff1a;I.MX6ULL_系统篇(16) uboot分析-启动流程 – WSY Personal Blog (cpolar.cn) 前面我们详细的分析了 uboot 的顶层 Makefile&#xff0c;了解了 uboot 的编译流程。本章我们来详细的分析一下 uboot 的启动流程&#xff0c;理清 uboot 是如何启动的。通过对 …

虹科资讯| 虹科AR荣获汽车后市场“20佳”维修工具评委会提名奖!

2022 虹科荣获20佳维修工具 评委会提名奖 特大喜讯&#xff0c;在2月16日《汽车维修与保养》杂志主办的第十八届汽车后市场“20佳”评选活动中&#xff0c;虹科的产品“M400智能AR眼镜”凭借在AR领域的专业实力&#xff0c;通过层层筛选&#xff0c;在102款入围产品中脱颖而出…

GIT:【基础三】Git工作核心原理

目录 一、Git本地四个工作区域 二、Git提交文件流程 一、Git本地四个工作区域 工作目录(Working Directory)&#xff1a;电脑上存放开发代码的地方。暂存区(Stage/Index)&#xff1a;用于l临时存放改动的文件&#xff0c;本质上只是一个文件&#xff0c;保存即将提交到文件列…

[ 对比学习篇 ] 经典网络模型 —— Contrastive Learning

&#x1f935; Author &#xff1a;Horizon Max ✨ 编程技巧篇&#xff1a;各种操作小结 &#x1f3c6; 神经网络篇&#xff1a;经典网络模型 &#x1f4bb; 算法篇&#xff1a;再忙也别忘了 LeetCode [ 对比学习篇 ] 经典网络模型 —— Contrastive Learning&#x1f680; …

MongoDB介绍及使用教程

文章目录一、MongoDB介绍1. 什么是MongoDB2. 为什么要用MongoDB3. MongoDB的应用场景4. MongoDB基本概念二、MongoDB使用教程1.下载安装&#xff08;Windows&#xff09;2.MongoDB Conpass简单使用&#xff08;选学&#xff09;3.使用navicat连接MongoDB4.JAVA项目中使用MongoD…

JVM11 垃圾回收

1.1GC分类与性能指标 垃圾收集器没有在规范中进行过多的规定&#xff0c;可以由不同的厂商、不同版本的JVM来实现。 从不同角度分析垃圾收集器&#xff0c;可以将GC分为不同的类型。 Java不同版本新特性 语法层面&#xff1a;Lambda表达式、switch、自动拆箱装箱、enumAPI层面…

AI稳定生成图工业链路打造

前沿这篇文章会以比较轻松的方式&#xff0c;跟大家交流下如何控制文本生成图片的质量。要知道如何控制文本生成质量&#xff0c;那么我们首先需要知道我们有哪些可以控制的参数和模块。要知道我们有哪些控制的参数和模块&#xff0c;我们就得知道我们文本生成图片的这架机器或…

新手福利——x64逆向基础

一、x64程序的内存和通用寄存器 随着游戏行业的发展&#xff0c;x32位的程序已经很难满足一些新兴游戏的需求了&#xff0c;因为32位内存的最大值为0xFFFFFFFF&#xff0c;这个值看似足够&#xff0c;但是当游戏对资源需求非常大&#xff0c;那么真正可以分配的内存就显得捉襟…

测试人员如何运用好OKR

在软件测试工作中是不是还不知道OKR是什么?又或者每次都很害怕写OKR?或者总觉得很迷茫&#xff0c;不知道目标是什么? OKR 与 KPI 的区别 去年公司从KPI换OKR之后&#xff0c;我也有一段抓瞎的过程&#xff0c;然后自己找了两本书看&#xff0c;一本是《OKR工作法》&#xf…

WPF_ObservableCollection基本使用及其注意项

文章目录一、引言二、ObservableCollection三、结语一、引言 在GUI编程中经常会用到条目控件&#xff0c;常见的如ComboBox&#xff08;下拉列表框&#xff09;&#xff0c;它内部往往有多个项。 在使用一些图形框架&#xff08;Qt、WinForm&#xff09;上进行原始开发时&…

安卓mvvm

AndroidX的意思是android extension libraries, 也就是安卓扩展包 AndroidX其实是Jetpack类库的命名空间 (190条消息) AndroidX初识_Neda Wang的博客-CSDN博客https://blog.csdn.net/weixin_38261570/article/details/111500044 viewmodel ViewModel类旨在以注重生命周期的方…

【机器学习】决策树-C4.5算法

1.C4.5算法 C4.5算法与ID3相似&#xff0c;在ID3的基础上进行了改进&#xff0c;采用信息增益比来选择属性。ID3选择属性用的是子树的信息增益&#xff0c;ID3使用的是熵&#xff08;entropy&#xff0c; 熵是一种不纯度度量准则&#xff09;&#xff0c;也就是熵的变化值&…

回溯算法理论基础及组合问题

文章目录回溯算法理论基础什么是回溯法回溯法的效率回溯法解决的问题如何理解回溯法回溯法模板组合问题回溯算法理论基础 什么是回溯法 回溯法也可以叫做回溯搜索法&#xff0c;它是一种搜索的方式。 回溯是递归的副产品&#xff0c;只要有递归就会有回溯。 所以以下讲解中&…

LPWAN及高效弹性工业物联网核心技术方案

20多年前的一辆拖拉机就是一个纯机械的产品&#xff0c;里面可能并没有电子或者软件的构成&#xff1b;而随后随着软件的发展&#xff0c;拖拉机中嵌入了软件&#xff0c;它能控制发动机的功率及拖拉机防抱死系统&#xff1b;接下来&#xff0c;通过融入各种软件&#xff0c;拖…

js逆向基础篇-某房地产网站-登录

提示!本文章仅供学习交流,严禁用于任何商业和非法用途,如有侵权,可联系本文作者删除! 网站链接:aHR0cHM6Ly9tLmZhbmcuY29tL215Lz9jPW15Y2VudGVyJmE9aW5kZXgmY2l0eT1iag== 案例分析: 本篇文章分析的是登录逻辑。话不多说,先看看登录中有哪些加密参数,在登录页面随便输入…

K8S DNS解析过程和延迟问题

一、Linux DNS查询解析原理&#xff08;对于调用glibc库函数gethostbyname的程序&#xff09;我们在浏览器访问www.baidu.com这个域名&#xff0c;dns怎么查询到这台主机呢&#xff1f;  1、在浏览器中输入www.baidu.com域名&#xff0c;操作系统会先查找本地DNS解析器缓存&a…

实例2:树莓派GPIO控制外部LED灯闪烁

实例2&#xff1a;树莓派GPIO控制外部LED灯闪烁 实验目的 通过背景知识学习&#xff0c;了解四足机器人mini pupper搭载的微型控制计算机&#xff1a;树莓派。通过树莓派GPIO操作的学习&#xff0c;熟悉GPIO的读写控制。通过外部LED灯的亮灭控制&#xff0c;熟悉树莓派对外界…

vue3 + vite 使用 svg 可改变颜色

文章目录vue3 vite 使用 svg安装插件2、配置插件 vite.config.js3、根据vite配置的svg图标文件夹&#xff0c;建好文件夹&#xff0c;把svg图标放入4、在 src/main.js内引入注册脚本5、创建一个公共SvgIcon.vue组件6.1 全局注册SvgIcon.vue组件6.2、在想要引入svg的vue组件中引…

Boom 3D最新版本下载电脑音频增强应用工具

为了更好地感受音乐的魅力&#xff0c;Boom 3D 可以让你对音效进行个性化增强&#xff0c;并集成 3D 环绕立体声效果&#xff0c;可以让你在使用任何耳机时&#xff0c;都拥有纯正、优质的音乐体验。Boom 3D是一款充满神奇魅力的3D环绕音效升级版&#xff0c;BOOM 3D是一个全新…