【Pytorch项目实战】之迁移学习:特征提取、微调、特征提取+微调、雾霾清除

news2025/1/11 5:46:11

文章目录

  • 迁移学习(Transfer Learning)
    • 方法一:特征提取(Feature Extraction)
    • 方法二:微调(Fine Tuning)
    • (一)实战:基于特征提取的迁移学习(数据集:CIFAR-10)
    • (二)实战:基于微调的迁移学习(数据集:CIFAR-10)
    • (三)实战:基于迁移学习的图像雾霾清除
    • (四)实战:基于迁移学习的102种花分类(先特征提取,后微调)
    • (五)实战:基于迁移学习的102种花分类(自定义Dataset)

迁移学习(Transfer Learning)

迁移学习是一种机器学习方法。

  • 具体过程:把任务A预训练模型(网络结构与权重参数),迁移到任务B上。A任务可以是识别图像中的车辆,而B任务可以是识别卡车、汽车、公交车等。
    在这里插入图片描述
  • 优点:加速训练过程,提升深度模型的性能。
  • 应用:常用于大数据,深网络。如:计算机视觉、自然语言处理。
  • 主要有三种方法:特征提取、微调、特征提取+微调

方法一:特征提取(Feature Extraction)

主要步骤:

  • 11、冻结除最后一个全连接层之外的所有网络的权重参数(即取消梯度更新:requires_grad=False);
  • 22、依据实际任务修改最后一个全连接层的分类器,随机初始化其参数,然后仅训练该层网络。
    在这里插入图片描述

方法二:微调(Fine Tuning)

  • 主要步骤:在预训练模型上,添加新的随机初始化层,且不冻结预训练模型的网络参数,但会使用较小的学习率
  • 优点:
    (1)虽然训练时间更长但精度更高;
    (2)可以减少训练参数的数量,避免过拟合。
  • 常用方法:固定底层的参数,调整顶层或具体层的参数。
    在这里插入图片描述

(一)实战:基于特征提取的迁移学习(数据集:CIFAR-10)

链接:https://pan.baidu.com/s/18Bzu-MU_RS594QCZ8JJQEQ?pwd=efz6
提取码:efz6

在这里插入图片描述

import torch
import torchvision
import torchvision.transforms as transforms
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'     # "OMP: Error #15: Initializing libiomp5md.dll"
#############################################################


def imshow(img):
    """显示图像"""
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


def get_acc(output, label):
    """计算准确度"""
    total = output.shape[0]
    _, pred_label = output.max(1)
    num_correct = (pred_label == label).sum().item()
    return num_correct / total


def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    """模型训练"""
    prev_time = datetime.now()

    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net = net.train()       # 训练模型
        for im, label in train_data:
            im = im.to(device)          # (bs, 3, h, w)
            label = label.to(device)    # (bs, h, w)
            output = net(im)                    # 前向传播
            loss = criterion(output, label)     # 损失函数
            optimizer.zero_grad()               # 梯度清零
            loss.backward()                     # 后向传播
            optimizer.step()                    # 梯度更新

            train_loss += loss.item()
            train_acc += get_acc(output, label)

        # 打印运行时间
        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)

        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net = net.eval()        # 验证模型
            for im, label in valid_data:
                im = im.to(device)          # (bs, 3, h, w)
                label = label.to(device)    # (bs, h, w)
                output = net(im)                        # 前向传播
                loss = criterion(output, label)         # 损失函数
                valid_loss += loss.item()
                valid_acc += get_acc(output, label)
            # 每个Epoch,打印结果。
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, " %
                         (epoch, train_loss / len(train_data), train_acc / len(train_data), valid_loss / len(valid_data), valid_acc / len(valid_data)))
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                         (epoch, train_loss / len(train_data), train_acc / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str)


#############################################################
if __name__ == '__main__':
    # (1)下载数据、数据预处理、迭代器
    trans_train = transforms.Compose(
        [transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    trans_valid = transforms.Compose(
        [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=trans_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=trans_valid)
    testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    #############################################################
    # (2)随机获取部分训练数据
    dataiter = iter(trainloader)
    images, labels = dataiter.next()
    imshow(torchvision.utils.make_grid(images[:4]))                     # 显示图像
    print(' '.join('%5s' % classes[labels[j]] for j in range(4)))       # 打印标签
    #############################################################
    # (3)冻住模型的所有权重参数
    net = torchvision.models.resnet18(pretrained=True)      # 使用预训练的模型
    for param in net.parameters():
        param.requires_grad = False             # 冻住该模型的所有权重参数
    #############################################################
    # (4)替换最后一层全连接层
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")     # 检测是否有可用的GPU,有则使用,否则使用CPU。
    net.fc = torch.nn.Linear(512, 10)           # 将最后的全连接层改成十分类

    # 查看总参数及(全连接层)训练参数
    total_params = sum(p.numel() for p in net.parameters())
    print('总参数个数:{}'.format(total_params))
    total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print('需训练参数个数:{}'.format(total_trainable_params))
    #############################################################
    # (5)(只)训练全连接层权重参数
    net = net.to(device)        # 将构建的张量或者模型分配到相应的设备上。
    criterion = torch.nn.CrossEntropyLoss()         # 交叉熵损失函数
    optimizer = torch.optim.SGD(net.fc.parameters(), lr=1e-3, weight_decay=1e-3, momentum=0.9)      # 优化器(学习率降低)
    train(net, trainloader, testloader, 1, optimizer, criterion)

(二)实战:基于微调的迁移学习(数据集:CIFAR-10)

链接:https://pan.baidu.com/s/18Bzu-MU_RS594QCZ8JJQEQ?pwd=efz6
提取码:efz6

在这里插入图片描述

import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torchvision.datasets import ImageFolder
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'     # "OMP: Error #15: Initializing libiomp5md.dll"
#############################################################


def imshow(img):
    """显示图像"""
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


def get_acc(output, label):
    """计算准确度"""
    total = output.shape[0]
    _, pred_label = output.max(1)
    num_correct = (pred_label == label).sum().item()
    return num_correct / total


def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    """模型训练"""
    prev_time = datetime.now()

    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net = net.train()       # 训练模型
        for im, label in train_data:
            im = im.to(device)          # (bs, 3, h, w)
            label = label.to(device)    # (bs, h, w)
            output = net(im)                    # 前向传播
            loss = criterion(output, label)     # 损失函数
            optimizer.zero_grad()               # 梯度清零
            loss.backward()                     # 后向传播
            optimizer.step()                    # 梯度更新

            train_loss += loss.item()
            train_acc += get_acc(output, label)

        # 打印运行时间
        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)

        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net = net.eval()        # 验证模型
            for im, label in valid_data:
                im = im.to(device)          # (bs, 3, h, w)
                label = label.to(device)    # (bs, h, w)
                output = net(im)                        # 前向传播
                loss = criterion(output, label)         # 损失函数
                valid_loss += loss.item()
                valid_acc += get_acc(output, label)
            # 每个Epoch,打印结果。
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, " %
                         (epoch, train_loss / len(train_data), train_acc / len(train_data), valid_loss / len(valid_data), valid_acc / len(valid_data)))
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                         (epoch, train_loss / len(train_data), train_acc / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str)


#############################################################
if __name__ == '__main__':
    # (1)下载数据、数据预处理、迭代器
    trans_train = transforms.Compose(
        [transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)), transforms.RandomRotation(degrees=15), transforms.ColorJitter(),
         transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    trans_valid = transforms.Compose(
        [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=trans_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=trans_valid)
    testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    # (2)随机获取部分训练数据
    dataiter = iter(trainloader)
    images, labels = dataiter.next()
    imshow(torchvision.utils.make_grid(images))                             # 显示图像
    print(' '.join('%5s' % classes[labels[j]] for j in range(4)))           # 打印标签
    # (3)使用预训练的模型,并替换最后一层全连接层
    net = models.resnet18(pretrained=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")     # 检测是否有可用的GPU,有则使用,否则使用CPU。
    net.fc = nn.Linear(512, 10)         # 将最后的全连接层改成十分类
    # (4)模型训练
    net = net.to(device)        # 将构建的张量或者模型分配到相应的设备上。
    criterion = torch.nn.CrossEntropyLoss()         # 交叉熵损失函数
    optimizer = torch.optim.SGD(net.fc.parameters(), lr=1e-3, weight_decay=1e-3, momentum=0.9)      # 优化器(学习率降低)
    train(net, trainloader, testloader, 1, optimizer, criterion)

(三)实战:基于迁移学习的图像雾霾清除

链接:https://pan.baidu.com/s/1z1MKgoKc4T-iyJV4mFLHeg?pwd=y58o
提取码:y58o

在这里插入图片描述

import torch
import torch.nn as nn
import torchvision
import torch.backends.cudnn as cudnn
import torch.optim
import numpy as np
from torchvision import transforms
from PIL import Image
import glob
import matplotlib.pyplot as plt
from matplotlib.image import imread
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'     # "OMP: Error #15: Initializing libiomp5md.dll"
#############################################################
# 创建存放目标文件目录(如果文件不存在,则创建)
path = 'clean_photo/results'
if not os.path.exists(path):
    os.makedirs(path)


class model(nn.Module):
    """定义神经网络"""
    def __init__(self):
        super(model, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.e_conv1 = nn.Conv2d(3, 3, 1, 1, 0, bias=True)
        self.e_conv2 = nn.Conv2d(3, 3, 3, 1, 1, bias=True)
        self.e_conv3 = nn.Conv2d(6, 3, 5, 1, 2, bias=True)
        self.e_conv4 = nn.Conv2d(6, 3, 7, 1, 3, bias=True)
        self.e_conv5 = nn.Conv2d(12, 3, 3, 1, 1, bias=True)

    def forward(self, x):
        source = []
        source.append(x)
        x1 = self.relu(self.e_conv1(x))
        x2 = self.relu(self.e_conv2(x1))
        concat1 = torch.cat((x1, x2), 1)
        x3 = self.relu(self.e_conv3(concat1))

        concat2 = torch.cat((x2, x3), 1)
        x4 = self.relu(self.e_conv4(concat2))
        concat3 = torch.cat((x1, x2, x3, x4), 1)
        x5 = self.relu(self.e_conv5(concat3))
        clean_image = self.relu((x5 * x) - x5 + 1)
        return clean_image


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # 检测是否有可用的GPU,有则使用,否则使用CPU。
net = model().to(device)


def cl_image(image_path):
    data = Image.open(image_path)
    data = (np.asarray(data) / 255.0)
    data = torch.from_numpy(data).float()
    data = data.permute(2, 0, 1)
    data = data.to(device).unsqueeze(0)

    ##########################################################
    # 加载预训练模型的权重参数
    net.load_state_dict(torch.load('clean_photo/dehazer.pth', map_location=torch.device('cpu')))     # CPU加载模型
    # net.load_state_dict(torch.load('dehazer.pth'))        # GPU加载模型
    ##########################################################
    clean_image = net.forward(data)     # 前向传播
    # 保存图像(自定义保存地址)
    torchvision.utils.save_image(torch.cat((data, clean_image), 0), "clean_photo/" + image_path.split("/")[-1])
    # split("/")[-1]: 获取分隔符最后一个字符串


if __name__ == '__main__':
    test_list = glob.glob(r"clean_photo/test_images\*")
    for image in test_list:
        cl_image(image)
        print(image, "done!")
    img = imread('./clean_photo/test_images/canyon.png')
    plt.imshow(img)
    plt.show()
    

(四)实战:基于迁移学习的102种花分类(先特征提取,后微调)

链接:https://pan.baidu.com/s/1nzV0_PorIupFVXlePoTzsw?pwd=ni9i
提取码:ni9i


PyTorch深度学习模型的保存和加载
CPU与GPU加载模型的区别:torch.load()

在这里插入图片描述

花朵数据分训练集和测试集,而每个数据集下有102个文件夹,分别存放102种花,文件夹的名字即对应花朵的标签。
在这里插入图片描述

import os
import time
import copy
import json
import matplotlib.pyplot as plt
import numpy as np
import torch
# from torch import nn
# import torch.optim as optim
# import torchvision
from torchvision import transforms, models, datasets

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'     		# "OMP: Error #15: Initializing libiomp5md.dll"
###################################################################################################
# 迁移学习:即建立在已经训练好的网络模型(权重参数)基础上,继续训练。
# 		———— torchvision提供了很多经典网络模型。
# 		注意1:在已训练好模型基础上,将(最后一层)全连接层的权重参数,根据实际任务需要重新训练。比如:需要将10分类更新为50分类;
# 		注意2:<11>可以全部重头训练;<22>只训练咱们任务的最后一层,因为前几层都是做特征提取,任务目标是一致的。
#
# 模型训练步骤如下:
# 		(1)【模块1】:提取全连接层之前的网络模型(权重参数),且设置权重参数不更新(即冻住该模型);
# 		(2)【模块2】:根据实际任务需要,自定义全连接层的权重参数;再搭配上【模块1】,进行(全连接层的权重参数)训练。
# 		(3)基于【模块1】、【模块2】构建网络模型,且设置权重参数更新,进行训练。
###################################################################################################


# 网络模型初始化(下载torchvision已经训练好的网络模型和权重参数)
def initialize_model(model_name, classes_num, feature_extract=True, use_pretrained=True):
	model_ft = models.resnet18(pretrained=use_pretrained)		# 模型初始化
	if feature_extract:											# 是否更新模型参数
		for param in model_ft.parameters():
			param.requires_grad = False							# 提取已经训练好的权重参数(不再更新)
	num_ftrs = model_ft.fc.in_features
	model_ft.fc = torch.nn.Linear(num_ftrs, classes_num)  		# 全连接网络:设置分类数目(根据实际任务)
	input_size = 64  											# 设置输入图像大小(根据实际任务)
	return model_ft, input_size


def train_model(model, dataloaders, optimizer, criterion, num_epochs, filename):
	since = time.time()								# 统计运行时间
	best_acc = 0									# 最优精确度
	model.to(device)								# 加载模型到CPU/GPU
	val_acc_history = []							# 验证集历史精确度
	train_acc_history = []							# 验证集历史精确度
	train_losses = []								# 验证集损失值
	valid_losses = []								# 验证集损失值
	lr_s = [optimizer.param_groups[0]['lr']]		# 学习率
	best_model_wts = copy.deepcopy(model.state_dict())				# 最好的那次模型,后续会变的,先初始化

	for epoch in range(num_epochs):
		print('-' * 50)												# 切割字符串
		print('Epoch = {}/{}'.format(epoch, num_epochs - 1))		# 打印当前第几轮epoch

		# 训练模型和验证模型
		for phase in ['train', 'valid']:
			if phase == 'train':
				model.train()  		# 切换训练模型
			else:
				model.eval()  		# 切换验证模型

			running_loss = 0.0			# 单个epoch的损失
			running_corrects = 0		# 单个epoch的准确率
			for inputs, labels in dataloaders[phase]:		# 遍历(训练集和验证集)
				inputs = inputs.to(device)  			# (图像)加载到CPU或GPU中
				labels = labels.to(device)				# (标签)加载到CPU或GPU中
				optimizer.zero_grad()					# 梯度清零
				outputs = model(inputs)					# 前向传播(每个图像输出N个值,对应N分类)
				loss = criterion(outputs, labels)		# 损失函数
				_, preds = torch.max(outputs, 1)		# 预测结果(取最大概率值对应的分类结果)

				# 梯度更新(仅限训练阶段)
				if phase == 'train':
					loss.backward()						# 后向传播
					optimizer.step()					# 参数更新

				# 计算损失
				running_loss += loss.item() * inputs.size(0)  				# 累加当前batch的损失值。0表示batch维度
				running_corrects += torch.sum(preds == labels.data)  		# 累加当前batch的准确率(预测结果最大值和真实值是否一致)

			epoch_loss = running_loss / len(dataloaders[phase].dataset) 					# 计算每个epoch平均损失
			epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)			# 计算每个epoch准确度
			time_elapsed = time.time() - since  							# 计算一个epoch计算时间
			print('Time_elapsed: {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
			print('{}_loss: {:.4f} 		acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

			# 验证模型。提取精确度最高的模型(迭代训练,可能会过拟合)
			if phase == 'valid' and epoch_acc > best_acc:
				best_acc = epoch_acc									# 记录最优精确度
				best_model_wts = copy.deepcopy(model.state_dict())		# 复制当前最好的权重参数
				# 权重参数(字典结构:key是每个网络层的名字,value是权重参数) + 最优准确度 + 优化器参数(lr)
				state = {'state_dict': model.state_dict(), 'best_acc': best_acc, 'optimizer': optimizer.state_dict()}
				torch.save(state, filename)								# 保存(当前模型)训练好的权重参数

			if phase == 'valid':
				val_acc_history.append(epoch_acc)
				valid_losses.append(epoch_loss)
			if phase == 'train':
				train_acc_history.append(epoch_acc)
				train_losses.append(epoch_loss)

		print('learning_rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
		lr_s.append(optimizer.param_groups[0]['lr'])		# 保存当前epoch的学习率
		optimizer.step()  					# 参数更新

	print('*' * 50)							# 切割字符串
	time_elapsed = time.time() - since		# 训练总时间
	print('Training_total_time {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
	print('Best_acc(valid): {:4f}'.format(best_acc))

	model.load_state_dict(best_model_wts)		# 训练完后,提取最优准确度对应的网络模型权重参数。
	return model, val_acc_history, train_acc_history, valid_losses, train_losses, lr_s


def im_convert(tensor):
	""" 展示数据 """
	image = tensor.to("cpu").clone().detach()		# 将torch.tensor提取到cpu下
	image = image.numpy().squeeze()					# 并转换为numpy格式,且降维处理
	image = image.transpose(1, 2, 0)				# 维度转换:size * channel(torch中的图像维度:channel * size)
	image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))		# 预处理操作改变源图像,需还原
	image = image.clip(0, 1)
	return image


#################################################################################################################
# (1)指定数据文件的地址
data_dir = './flower_data/'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'
# 数据分为训练集和测试集两个文件夹:每个文件夹下有102个子文件夹,每个子文件夹下存放对应类别的图像。
#################################################################################################################
# (2)数据增强
data_transforms = {
	'train': transforms.Compose([
		transforms.Resize([96, 96]),				# Resize的作用是对图像进行缩放。
		transforms.RandomRotation(45),				# RandomRotation的作用是对图像进行随机旋转。
		transforms.CenterCrop(64),					# CenterCrop的作用是从图像的中心位置裁剪指定大小的图像
		transforms.RandomHorizontalFlip(p=0.5),		# RandomHorizontalFlip的作用是以一定的概率对图像进行水平翻转。
		transforms.RandomVerticalFlip(p=0.5),		# RandomVerticalFlip的作用是以一定的概率对图像进行垂直翻转。
		transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),		# ColorJitter的作用是随机修改图片的亮度、对比度和饱和度
		transforms.RandomGrayscale(p=0.025),		# RandomGrayscale的作用是以一定的概率将图像变为灰度图像。
		transforms.ToTensor(),						# 将PIL Image或numpy.ndarray转为pytorch的Tensor,并会将像素值由[0, 255]变为[0, 1]之间。
		transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),				# Normalize的作用是用均值和标准差对Tensor进行归一化处理。
	'valid': transforms.Compose([
		transforms.Resize([64, 64]),
		transforms.ToTensor(),
		transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}
#################################################################################################################
# (3)数据预处理
batch_size = 128
# 添加(指定数据)路径 + 数据增强。		ImageFolder:提取多个文件夹下的数据(训练集+测试集)		DataLoader:批量数据读取
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
# 字典结构:{'train': x, 'valid': x}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes
#################################################################################################################
# (4)判断GPU是否可用
train_on_gpu = torch.cuda.is_available()
if not train_on_gpu:
	print('CUDA is not available.  Training on CPU ...')
else:
	print('CUDA is available!  Training on GPU ...')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#################################################################################################################
# (5)加载torchvision.models中已经训练好的网络模型、权重参数。
model_name = 'resnet'  		# 加载网络模型:['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
feature_extract = True 		# 提取权重参数
classes_num = 102			# 设置分类数目(根据实际任务)
model_ft, input_size = initialize_model(model_name, classes_num, feature_extract=True, use_pretrained=True)		# 模型初始化
model_ft = model_ft.to(device)							# 模型加载到CPU/GPU中
#################################################################################################################
# (6)迁移学习第一步:模型训练(只训练输出层) ———— 即冻住全连接层之前的权重参数,只更新全连接层的权重参数。
params_to_update = model_ft.parameters()				# 提取初始化模型的权重参数
if feature_extract == True:
	params_to_update = []
	# model.named_parameters():返回每一层网络的名称和参数内容(权重和偏置)
	for name, param in model_ft.named_parameters():
		if param.requires_grad == True:
			params_to_update.append(param)				# 提取每一层网络的权重参数(已经训练好的权重参数)
			print("\t", name)
else:
	for name, param in model_ft.named_parameters():
		if param.requires_grad == True:					# 重新训练权重参数
			print("\t", name)

optimizer_ft = torch.optim.Adam(params_to_update, lr=1e-2)								# 优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)		# 学习率。每7个epoch,衰减为原来的1/10
criterion = torch.nn.CrossEntropyLoss()													# 损失函数
num_epochs = 1
filename = 'checkpoint.pth'			# 自定义模型保存后的名字
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, lr_s = train_model(model_ft, dataloaders, optimizer_ft, criterion, num_epochs, filename)
#################################################################################################################
# (7)迁移学习第二步:模型训练(其余网络层+全连接层) ———— 即更新权重参数,最后得到当前网络模型的权重参数。
for param in model_ft.parameters():
	param.requires_grad = True		# 将所有的权重参数都设置为True(需要更新)

optimizer = torch.optim.Adam(model_ft.parameters(), lr=1e-3)							# 优化器(学习率可以小一点)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)		# 学习率。每7个epoch,衰减为原来的1/10
criterion = torch.nn.CrossEntropyLoss()													# 损失函数

checkpoint = torch.load(filename)						# 加载已经训练好的权重参数
best_acc = checkpoint['best_acc']						# 提取该模型的最优准确度
model_ft.load_state_dict(checkpoint['state_dict'])		# 加载该模型的权重参数
num_epochs = 1
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, lr_s = train_model(model_ft, dataloaders, optimizer, criterion, num_epochs, filename)
#################################################################################################################
# (8)模型验证
dataiter = iter(dataloaders['valid'])		# 提取验证集
images, labels = dataiter.next()			# 提取图像与标签(每次提取数据大小:一个epoch)

model_ft.eval()			# 模型验证
if train_on_gpu:							# 提取输出结果(CPU与GPU两种方法)
	output = model_ft(images.cuda())
else:
	output = model_ft(images)
_, preds_tensor = torch.max(output, 1)		# 提取预测结果矩阵(取概率最大值)
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
# 格式转换:将tensor转换成numpy,包括CPU与GPU两种转换方法。
##########################################
# (9)画图(局部结果展示)
with open('cat_to_name.json', 'r') as f:
	cat_to_name = json.load(f)				# json.load():读取文件句柄

fig = plt.figure(figsize=(20, 20))
columns = 5		# 画图的行数
rows = 4		# 画图的列数
for idx in range(columns * rows):
	ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])		# 在画布上进行子区域划分。
	plt.imshow(im_convert(images[idx]))										# 画图。图像格式转换(tensor to numpy)并进行计算
	# 标题展示预测结果:前一个值为label,后一个值为预测值。若预测值为真,则为绿色,否则为红色。
	ax.set_title("label={} (pred={})" .format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
														color=("green" if cat_to_name[str(preds[idx])] == cat_to_name[str(labels[idx].item())] else "red"))
plt.show()

(五)实战:基于迁移学习的102种花分类(自定义Dataset)

链接:https://pan.baidu.com/s/1nzV0_PorIupFVXlePoTzsw?pwd=ni9i
提取码:ni9i

train标签格式如下:
在这里插入图片描述

在这里插入图片描述

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'     # "OMP: Error #15: Initializing libiomp5md.dll"
##########################################################################


# (1)自定义Dataset
class FlowerDataset(Dataset):			# 继承torch.utils.data.Dataset
    def __init__(self, root_dir, ann_file, transform=None):
        """函数功能:参数初始化"""
        self.root_dir = root_dir        # 获取数据的根目录路径
        self.ann_file = ann_file        # 获取数据文件(.txt)
        self.transform = transform      # 图像预处理

        self.img_label = self.load_annotations()
        # 加载(.txt)文件,并获得图像名称和对应的标签
        self.img = [os.path.join(self.root_dir, img) for img in list(self.img_label.keys())]
        # (1)将【图像名称】数据转换为list;(2)遍历所有图像名称;(3)通过路径拼接,得到每张图像的存放地址;(4)并将其添加到系统路径中
        self.label = [label for label in list(self.img_label.values())]
        # (1)将【标签】数据转换为list;(2)遍历所有标签;

    def __len__(self):
        """函数功能:获取数据集大小"""
        return len(self.img)

    def __getitem__(self, idx):     # 默认自动打乱数据(idx:表示索引值)
        """函数功能:获取图像和标签"""
        image = Image.open(self.img[idx])       # 获取图像索引,并打开其所在路径,得到图像数据。
        label = self.label[idx]                 # 获取图像对应的标签值
        if self.transform:                      # 判断是否需要图像预处理
            image = self.transform(image)       # 如是,则执行图像预处理
        label = torch.from_numpy(np.array(label))       # 格式转换:numpy转换为tensor
        return image, label

    def load_annotations(self):
        """函数功能:读取(.txt)文件,提取数据集"""
        """文件存放的数据格式:每一行对应一个数据,格式为 = name label"""
        data_infos = {}     # 数据储存。字典结构 = {key=name, value=label}
        with open(self.ann_file) as f:
            # (1)逐行读取;(2)并以“ 空格符 ”进行分割;(3)然后保存为列表
            samples = [x.strip().split(' ') for x in f.readlines()]
            for filename, get_label in samples:
                data_infos[filename] = np.array(get_label, dtype=np.int64)       # np.array: 列表转换为ndarray
        return data_infos
        ##############################################################
        # strip():用于移除字符串开头或结尾指定的字符或字符串(默认为空格或换行符)。
        #       str = "00000003210Runoob01230000000";
        #       print str.strip( '0' );  # 去除首尾字符 0
        ##############################################################


# (2)图像预处理(图像预处理操作都是在DataLoader中完成的)
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize([96, 96]),				# Resize的作用是对图像进行缩放。
        transforms.RandomRotation(45),				# RandomRotation的作用是对图像进行随机旋转。
        transforms.CenterCrop(64),					# CenterCrop的作用是从图像的中心位置裁剪指定大小的图像
        transforms.RandomHorizontalFlip(p=0.5),		# RandomHorizontalFlip的作用是以一定的概率对图像进行水平翻转。
        transforms.RandomVerticalFlip(p=0.5),		# RandomVerticalFlip的作用是以一定的概率对图像进行垂直翻转。
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),		# ColorJitter的作用是随机修改图片的亮度、对比度和饱和度
        transforms.RandomGrayscale(p=0.025),		# RandomGrayscale的作用是以一定的概率将图像变为灰度图像。
        transforms.ToTensor(),						# 将PIL Image或numpy.ndarray转为pytorch的Tensor,并会将像素值由[0, 255]变为[0, 1]之间。
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),				# Normalize的作用是用均值和标准差对Tensor进行归一化处理。
    'valid': transforms.Compose([
        transforms.Resize([64, 64]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}

# (3)自定义图像数据路径
data_dir = './flower_data/'
train_dir = data_dir + '/train_filelist'
valid_dir = data_dir + '/val_filelist'

# (4)实例化dataloader
train_dataset = FlowerDataset(root_dir=train_dir, ann_file='./flower_data/train.txt', transform=data_transforms['train'])
val_dataset = FlowerDataset(root_dir=valid_dir, ann_file='./flower_data/val.txt', transform=data_transforms['valid'])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)

# (5)模型测试
image, label = iter(train_loader).next()        # 迭代一个batch数据,然后next获取下一个batch数据(系统定性写法)
sample = image[0].squeeze()                     # 维度压缩:1*3*64*64 -> 3*64*64
sample = sample.permute((1, 2, 0)).numpy()      # 维度变换:3*64*64 -> 64*64*3
sample *= [0.229, 0.224, 0.225]                 # 还原(标准化)预处理:均值
sample += [0.485, 0.456, 0.406]                 # 还原(标准化)预处理:标准差
plt.imshow(sample)                                  # 画图
plt.title('label={}'.format(label[0].numpy()))      # 标题打印标签
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/186185.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2023第三方应用苹果电脑磁盘读写工具Tuxera NTFS

今天&#xff0c;小编要来分享的是Mac下一款实用的NTFS读写软件——Tuxera NTFS&#xff0c;我们都知道OS X默认是不支持NTFS格式写入的&#xff0c;对于很多使用U盘或移动硬盘写操作的朋友来说非常的不便。而Tuxera NTFS很好的解决了这个问题。小子这次带来的是2023版本。 Tu…

kali入侵电脑

kali入侵电脑 注意&#xff1a;本文仅用于教学目的 1 kali制作exe控制电脑&#xff08;msfvenom&#xff09; kali是黑客常用的系统&#xff0c;里面集成了很多的攻击软件&#xff0c;这里我给大家演示一种使用kali制作.exe文件来控制自己电脑的方式。 msfvenom a Metasploit s…

Vue2 和Vue 3的区别

Vue 2 和 Vue 3的区别 1.双向数据绑定原理不同 Vue2 的双向数据绑定是利用ES5的一个APIObject.definePropert() 对数据进行劫持&#xff0c;结合发布订阅模式的方式来实现的。 Vue3 中使用ES6的Proxy API对数据代理。 Vue3 使用数据代理的优势有以下几点&#xff1a;1&#x…

深圳MES系统如何助力注塑企业实现数字化发展

家用电器、电子产品、日用品、医疗保健、汽车零部件、新能源以及建筑、玩具等行业对注塑制品需求量日益增长。注塑企业提供的各式各样注塑产品已深入到经济生活的各个领域&#xff0c;为国家经济的各个部门包括轻工业和重工业提供关键的支持。 现状 现在注塑企业的注塑机工作…

RTP协议基本分析(RTSP、WebRTC使用)

1、介绍 实时流传输协议&#xff08;RTSP&#xff1a;Real Time Streaming Protocol&#xff09;是一种网络传输协议&#xff0c;旨在发送低延迟流。 该协议由RealNetworks&#xff0c;Netscape和哥伦比亚大学的专家在1996年开发。它定义了应如何打包流中的数 据以进行传输。 …

【GD32F427开发板试用】INA226完成电流电压采集

本篇文章来自极术社区与兆易创新组织的GD32F427开发板评测活动&#xff0c;更多开发板试用活动请关注极术社区网站。作者&#xff1a;จุ๊บ冰语 前言 本次有幸参与并通过了极术社区组织的【GD32F427开发板试用】活动&#xff0c;让我对国产兆易创新的GD32处理器有了更深刻的…

Scala系列之:函数式编程

Scala系列之&#xff1a;函数式编程一、面向对象编程和函数式编程二、函数基本语法三、函数和方法的区别四、函数定义五、函数参数六、函数至简原则七、匿名函数一、面向对象编程和函数式编程 面向对象编程&#xff1a; 解决问题&#xff0c;分解对象&#xff0c;行为&#x…

【最新消息】苹果放出新大招??!!

各位开发者新年快乐&#xff0c;许久没有更新了&#xff0c;近期我收到反馈意思遇到苹果回复的新政策&#xff0c;不知道各位开发者有没有碰到过&#xff0c;我也会在下文提出我的猜测&#xff0c;要是有开发者也遇到了同样的问题&#xff0c;欢迎一起交流哦。 疑似新政策&…

【寒假每日一题】洛谷 P1088 [NOIP2004 普及组] 火星人

题目链接&#xff1a;P1088 [NOIP2004 普及组] 火星人 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 题目描述 人类终于登上了火星的土地并且见到了神秘的火星人。人类和火星人都无法理解对方的语言&#xff0c;但是我们的科学家发明了一种用数字交流的方法。这种交流方法是…

vue-query 初探

vue-query&#xff0c;类似于vuex/pinia&#xff0c;以缓存为目的&#xff0c;但侧重的是对网络请求的缓存。 这是我预想的使用场景&#xff1a;假设在各个页面都需要发起相同的请求&#xff0c;去获取数据&#xff0c;而这种数据在一定时间内不会发生变化&#xff0c;那么这种…

【JavaWeb】带你走进Maven

文章目录1 什么是Maven?2 Maven 常用命令3 Maven 生命周期4 Maven 坐标详解5 IDEA 导入 Maven 项目1 什么是Maven? 如今我们构建一个项目需要用到很多第三方的类库&#xff0c;如写一个使用Spring的Web项目就需要引入大量的jar包。一个项目Jar包的数量之多往往让我们瞠目结舌…

线程execute()与submit()区别

线程池中有两个提交任务的方法 向线程池提交任务的两种方式大致如下&#xff1a; 方式一&#xff1a;调用execute()方法 方式二&#xff1a;调用submit()方法 一、区别 以上的submit()和execute()两类方法的区别在哪里呢&#xff1f;大致有以下三点&#xff1a; 1.二者所接收…

引入“ 自动化测试 ”都需要满足哪些条件?

&#x1f4cc; 博客主页&#xff1a; 自动化软件测试 &#x1f4cc; 专注于软件测试领域相关技术实践和思考&#xff0c;持续分享自动化软件测试开发干货知识&#xff01; &#x1f4cc; 如果你也想学习软件测试&#xff0c;文末卡片有我的交流群&#xff0c;加入我们&#xff…

由浅入深,聊聊 LeakCanary 的那些事

引言 关于内存泄漏&#xff0c;Android 开发的小伙伴应该都再熟悉不过了&#xff0c;比如最常见的静态类间接持有了某个 Activity 对象&#xff0c;又比如某个组件库的订阅在页面销毁时没有及时清理等等&#xff0c;这些情况下多数时都会造成内存泄漏&#xff0c;从而对我们Ap…

linux内核-内存管理

linux内核内存管理 注意&#xff01;内核空间和用户空间都是处于虚拟空间中 Linux的虚拟地址空间范围为0&#xff5e;4G&#xff0c;Linux内核将这4G字节的空间分为两部分 内核空间&#xff1a; 最高的1G字节&#xff08;从虚拟地址0xC0000000到0xFFFFFFFF&#xff09;&…

RTSP,RTP,RTCP协议

一 RTSP 1 简介 实时流传输协议&#xff0c;是一个应用层协议&#xff08;TCP/IP网络体系中&#xff09;&#xff0c;它是一个多媒体播放控制协议&#xff0c;主要用来使用户在播放流媒体时可以像操作本地的影碟机一样进行控制&#xff0c;即可以对流媒体进行暂停/继续、后退和…

SAP FICO 关于资产的详细解析

SAP资产模块概述 一、概述 资产&#xff08;AA&#xff09;模块是资产会计模块的简称&#xff0c;是财务会计&#xff08;FI&#xff09;模块的一个子模块&#xff0c;主要处理与各类长期资产相关业务的模块。不单指固定资产&#xff0c;也不泛指资产负债表中的资产&#xff0c…

Week4

1.试题 历届真题 时间显示【第十二届】【省赛】【B组】 思路 不难发现,应该从小时往秒处理,这样可以用O(1)的时间复杂度求出,不过有比较麻烦的进位处理。 先看里面可以拼成几个小时,然后得到的小时%24,然后把总时间减去小时的时间,再看有多少分钟,分钟%60,都是此时判断分…

vue多环境配置之 .env配置文件

Vue之.env环境配置文件 .env文件是运行项目时的环境配置文件。但是在实际开发过程中&#xff0c;有本地环境、测试环境、预生产、生产环境等等&#xff0c;不同环境对应的配置会不一样。因此&#xff0c;需要通过不同的.env文件实现差异化配置。 * 文章目录Vue之.env环境配置文…

【JAVA核心知识】46:什么是零拷贝Zero-copy

零拷贝相较于传统的IO流程拥有更高的数据发送效率&#xff0c;无论是RocketMq,Kafka还是Netty等都用到了零拷贝技术&#xff0c;那究竟什么是零拷贝呢&#xff0c;零拷贝又是通过什么方式提升数据发送效率呢&#xff1f; 首先我们要明白&#xff0c;一次数据发送过程就是将磁盘…