pytorch实现半监督学习

news2025/2/5 20:32:11

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

半监督学习(Semi-Supervised Learning,SSL)结合了有监督学习和无监督学习的特点,通常用于部分数据有标签、部分数据无标签的场景。其主要步骤如下:

1. 数据准备

  • 有标签数据(Labeled Data):数据集的一部分带有真实的类别标签。
  • 无标签数据(Unlabeled Data):数据集的另一部分没有标签,仅有特征信息。
  • 数据预处理:对数据进行清理、标准化、特征工程等处理,以保证数据质量。

2. 选择半监督学习方法

常见的半监督学习方法包括:

  • 基于生成模型(Generative Models):如高斯混合模型(GMM)、变分自编码器(VAE)。
  • 基于一致性正则化(Consistency Regularization):如 MixMatch、FixMatch,利用数据增强来约束模型预测一致性。
  • 基于伪标签(Pseudo-Labeling):先用模型预测无标签数据的类别,然后将高置信度的预测作为新标签加入训练。
  • 图神经网络(Graph-Based Methods):如 Label Propagation,通过构造数据之间的图结构传播标签信息。

3. 训练初始模型

  • 仅使用有标签数据训练一个初始模型。
  • 选择合适的损失函数,如交叉熵损失(Cross-Entropy Loss)或均方误差(MSE Loss)。
  • 训练过程中可以使用数据增强、正则化等优化策略。

4. 利用无标签数据增强训练

  • 伪标签方法:用初始模型对无标签数据进行预测,筛选高置信度样本,加入有标签数据训练。
  • 一致性正则化:对无标签数据进行不同变换,要求模型的预测结果一致。
  • 联合训练:构造有监督损失(Supervised Loss)和无监督损失(Unsupervised Loss),综合优化。

5. 模型迭代更新

  • 重新利用训练后的模型预测无标签数据,产生新的伪标签或调整模型参数。
  • 通过半监督策略不断优化模型,使其对无标签数据的预测更加稳定。

6. 评估和测试

  • 使用测试集(通常是有标签的数据)评估模型性能。
  • 选择合适的评估指标,如准确率(Accuracy)、F1-score、AUC-ROC 等。

7. 调优和部署

  • 根据实验结果调整超参数,如伪标签置信度阈值、学习率等。
  • 结合业务需求,将最终模型部署到实际应用中。

关键步骤:

  1. 初始化模型:首先使用有标签数据训练模型。
  2. 生成伪标签:用训练好的模型对无标签数据进行预测,生成伪标签。
  3. 结合有标签和伪标签数据进行训练:用带有标签和无标签(伪标签)数据一起训练模型。
  4. 迭代训练:不断迭代,使用更新的模型生成新的伪标签,进一步优化模型。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt


# 简化的神经网络模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3)  # 缩小卷积层的输出通道
        self.fc1 = nn.Linear(8 * 26 * 26, 10)  # 调整全连接层的输入和输出尺寸

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = x.view(x.size(0), -1)  # 展平
        x = self.fc1(x)
        return x


# 自定义数据集
class CustomDataset(Dataset):
    def __init__(self, data, labels=None):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if self.labels is not None:
            return self.data[idx], self.labels[idx]
        else:
            return self.data[idx], -1  # 无标签数据


# 半监督训练函数
def pseudo_labeling_training(model, labeled_loader, unlabeled_loader, optimizer, device, threshold=0.95):
    model.train()
    labeled_loss_value = 0
    pseudo_loss_value = 0

    for (labeled_data, labeled_labels), (unlabeled_data, _) in zip(labeled_loader, unlabeled_loader):
        labeled_data, labeled_labels = labeled_data.to(device), labeled_labels.to(device)
        unlabeled_data = unlabeled_data.to(device)

        # 1. 有标签数据训练
        optimizer.zero_grad()
        labeled_output = model(labeled_data)
        labeled_loss = F.cross_entropy(labeled_output, labeled_labels)
        labeled_loss.backward()

        # 2. 无标签数据伪标签生成
        unlabeled_output = model(unlabeled_data)
        probs = F.softmax(unlabeled_output, dim=1)
        max_probs, pseudo_labels = torch.max(probs, dim=1)

        # 伪标签置信度筛选
        pseudo_mask = max_probs > threshold  # 置信度大于阈值的数据作为伪标签
        if pseudo_mask.sum() > 0:
            pseudo_labels = pseudo_labels[pseudo_mask]
            unlabeled_data_pseudo = unlabeled_data[pseudo_mask]

            # 3. 使用伪标签数据进行训练(确保无标签数据参与反向传播)
            optimizer.zero_grad()  # 清除之前的梯度
            pseudo_output = model(unlabeled_data_pseudo)
            pseudo_loss = F.cross_entropy(pseudo_output, pseudo_labels)
            pseudo_loss.backward()  # 计算反向梯度

        optimizer.step()  # 更新模型参数

        # 累加损失用于展示
        labeled_loss_value += labeled_loss.item()
        if pseudo_mask.sum() > 0:
            pseudo_loss_value += pseudo_loss.item()

    return labeled_loss_value / len(labeled_loader), pseudo_loss_value / len(unlabeled_loader)


# 模拟数据
num_labeled = 1000
num_unlabeled = 5000
data_dim = (1, 28, 28)  # 28x28 灰度图像
num_classes = 10

labeled_data = torch.randn(num_labeled, *data_dim)
labeled_labels = torch.randint(0, num_classes, (num_labeled,))
unlabeled_data = torch.randn(num_unlabeled, *data_dim)

labeled_dataset = CustomDataset(labeled_data, labeled_labels)
unlabeled_dataset = CustomDataset(unlabeled_data)

labeled_loader = DataLoader(labeled_dataset, batch_size=32, shuffle=True)  # 缩小批量大小
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=32, shuffle=True)  # 缩小批量大小

# 模型、优化器和设备设置
device = torch.device("cpu")  # 临时使用 CPU
model = SimpleCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练过程并记录损失
num_epochs = 10
labeled_losses = []
pseudo_losses = []

for epoch in range(num_epochs):
    labeled_loss, pseudo_loss = pseudo_labeling_training(model, labeled_loader, unlabeled_loader, optimizer, device)
    labeled_losses.append(labeled_loss)
    pseudo_losses.append(pseudo_loss)
    print(f"Epoch [{epoch + 1}/{num_epochs}] | Labeled Loss: {labeled_loss:.4f} | Pseudo Loss: {pseudo_loss:.4f}")

# 绘制损失曲线
plt.plot(range(num_epochs), labeled_losses, label='Labeled Loss')
plt.plot(range(num_epochs), pseudo_losses, label='Pseudo Label Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.title('Training Losses Over Epochs')
plt.show()

# 展示伪标签生成效果(可视化一些样本的伪标签预测结果)
model.eval()
with torch.no_grad():
    sample_unlabeled_data = unlabeled_data[:10].to(device)
    output = model(sample_unlabeled_data)
    probs = F.softmax(output, dim=1)
    _, predicted_labels = torch.max(probs, dim=1)

    # 展示预测的标签
    print("Generated Pseudo Labels for Samples:")
    print(predicted_labels)

    # 假设这些是伪标签预测的图片
    fig, axes = plt.subplots(2, 5, figsize=(12, 5))
    for i, ax in enumerate(axes.flat):
        # 将tensor转换为NumPy数组
        img = sample_unlabeled_data[i].cpu().numpy().squeeze()  # 转为NumPy数组
        ax.imshow(img, cmap='gray')  # 使用灰度显示图像
        ax.set_title(f"Pred: {predicted_labels[i].item()}")
        ax.axis('off')
    plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2293464.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Golang :用Redis构建高效灵活的应用程序

在当前的应用程序开发中,高效的数据存储和检索的必要性已经变得至关重要。Redis是一个快速的、开源的、内存中的数据结构存储,为各种应用场景提供了可靠的解决方案。在这个完整的指南中,我们将学习什么是Redis,通过Docker Compose…

deepseek+vscode自动化测试脚本生成

近几日Deepseek大火,我这里也尝试了一下,确实很强。而目前vscode的AI toolkit插件也已经集成了deepseek R1,这里就介绍下在vscode中利用deepseek帮助我们完成自动化测试脚本的实践分享 安装AI ToolKit并启用Deepseek 微软官方提供了一个针对AI辅助的插件,也就是 AI Toolk…

【大数据技术】Day07:本机DataGrip远程连接虚拟机MySQL/Hive

本机DataGrip远程连接虚拟机MySQL/Hive datagrip-2024.3.4VMware Workstation Pro 16CentOS-Stream-10-latest-x86_64-dvd1.iso写在前面 本文主要介绍如何使用本机的DataGrip连接虚拟机的MySQL数据库和Hive数据库,提高编程效率。 安装DataGrip 请按照以下步骤安装DataGrip软…

大语言模型的个性化综述 ——《Personalization of Large Language Models: A Survey》

摘要: 本文深入解读了论文“Personalization of Large Language Models: A Survey”,对大语言模型(LLMs)的个性化领域进行了全面剖析。通过详细阐述个性化的基础概念、分类体系、技术方法、评估指标以及应用实践,揭示了…

[论文学习]Adaptively Perturbed Mirror Descent for Learning in Games

[论文学习]Adaptively Perturbed Mirror Descent for Learning in Games 前言概述前置知识和问题约定单调博弈(monotone game)Nash均衡和Gap函数文章问题定义Mirror Descent 方法评价 前言 文章链接 我们称集合是紧的,则集合满足&#xff1…

【Unity踩坑】Unity项目管理员权限问题(Unity is running as administrator )

问题描述: 使用Unity Hub打开或新建项目时会有下面的提示。 解决方法: 打开“本地安全策略”: 在Windows搜索栏中输入secpol.msc并回车,或者从“运行”对话框(Win R,然后输入secpol.msc)启…

一文讲解Java中的ArrayList和LinkedList

ArrayList和LinkedList有什么区别? ArrayList 是基于数组实现的,LinkedList 是基于链表实现的。 二者用途有什么不同? 多数情况下,ArrayList更利于查找,LinkedList更利于增删 由于 ArrayList 是基于数组实现的&#…

使用 DeepSeek-R1 与 AnythingLLM 搭建本地知识库

一、下载地址Download Ollama on macOS 官方网站:Ollama 官方模型库:library 二、模型库搜索 deepseek r1 deepseek-r1:1.5b 私有化部署deepseek,模型库搜索 deepseek r1 运行cmd复制命令:ollama run deepseek-r1:1.5b 私有化…

MapReduce分区

目录 1. MapReduce分区1.1 哈希分区1.2 自定义分区 2. 成绩分组2.1 Map2.2 Partition2.3 Reduce 3. 代码和结果3.1 pom.xml中依赖配置3.2 工具类util3.3 GroupScores3.4 结果 参考 本文引用的Apache Hadoop源代码基于Apache许可证 2.0,详情请参阅 Apache许可证2.0。…

【Spring】Spring Cloud Alibaba 版本选择及项目搭建笔记

文章目录 前言1. 版本选择2. 集成 Nacos3. 服务间调用4. 集成 Sentinel5. 测试后记 前言 最近重新接触了 Spring Cloud 项目,为此参考多篇官方文档重新搭建一次项目,主要实践: 版本选择,包括 Spring Cloud Alibaba、Spring Clou…

C语言实现统计字符串中不同ASCII字符个数

在C语言编程中,经常会遇到一些对字符串进行处理的需求,今天我们就来探讨如何统计给定字符串中ASCII码在0 - 127范围内不同字符的个数。这不仅是一个常见的算法问题,也有助于我们更好地理解C语言中数组和字符操作的相关知识。 问题描述 对于给…

保姆级教程Docker部署Zookeeper官方镜像

目录 1、安装Docker及可视化工具 2、创建挂载目录 3、运行Zookeeper容器 4、Compose运行Zookeeper容器 5、查看Zookeeper运行状态 6、验证Zookeeper是否正常运行 1、安装Docker及可视化工具 Docker及可视化工具的安装可参考:Ubuntu上安装 Docker及可视化管理…

DeepSeek R1 简易指南:架构、本地部署和硬件要求

DeepSeek 团队近期发布的DeepSeek-R1技术论文展示了其在增强大语言模型推理能力方面的创新实践。该研究突破性地采用强化学习(Reinforcement Learning)作为核心训练范式,在不依赖大规模监督微调的前提下显著提升了模型的复杂问题求解能力。 技…

【Linux系统】信号:再谈OS与内核区、信号捕捉、重入函数与 volatile

再谈操作系统与内核区 1、浅谈虚拟机和操作系统映射于地址空间的作用 我们调用任何函数(无论是库函数还是系统调用),都是在各自进程的地址空间中执行的。无论操作系统如何切换进程,它都能确保访问同一个操作系统实例。换句话说&am…

自定义数据集 使用paddlepaddle框架实现逻辑回归

导入必要的库 import numpy as np import paddle import paddle.nn as nn 数据准备: seed1 paddle.seed(seed)# 1.散点输入 定义输入数据 data [[-0.5, 7.7], [1.8, 98.5], [0.9, 57.8], [0.4, 39.2], [-1.4, -15.7], [-1.4, -37.3], [-1.8, -49.1], [1.5, 75.6…

LabVIEW图片识别逆向建模系统

本文介绍了一个基于LabVIEW的图片识别逆向建模系统的开发过程。系统利用LabVIEW的强大视觉处理功能,通过二维图片快速生成对应的三维模型,不仅降低了逆向建模的技术门槛,还大幅提升了建模效率。 ​ 项目背景 在传统的逆向建模过程中&#xf…

MySQL(高级特性篇) 13 章——事务基础知识

一、数据库事务概述 事务是数据库区别于文件系统的重要特性之一 (1)存储引擎支持情况 SHOW ENGINES命令来查看当前MySQL支持的存储引擎都有哪些,以及这些存储引擎是否支持事务能看出在MySQL中,只有InnoDB是支持事务的 &#x…

前端进阶:深度剖析预解析机制

一、预解析是什么? 在前端开发中,我们常常会遇到一些看似不符合常规逻辑的代码执行现象,比如为什么在变量声明之前访问它,得到的结果是undefined,而不是报错?为什么函数在声明之前就可以被调用&#xff1f…

【力扣】53.最大子数组和

AC截图 题目 思路 这道题主要考虑的就是要排除负数带来的负面影响。如果遍历数组,那么应该有如下关系式: currentAns max(prenums[i],nums[i]) pre是之前记录的最大和,如果prenums[i]小于nums[i],就要考虑舍弃pre,从…

基于Spring Security 6的OAuth2 系列之七 - 授权服务器--自定义数据库客户端信息

之所以想写这一系列,是因为之前工作过程中使用Spring Security OAuth2搭建了网关和授权服务器,但当时基于spring-boot 2.3.x,其默认的Spring Security是5.3.x。之后新项目升级到了spring-boot 3.3.0,结果一看Spring Security也升级…