PyTorch深度学习网络(一:MLP)

news2024/11/15 8:40:38

全连接神经网络,又称多层感知机(MLP),是深度学习最基础的神经网络。全连接神经网络主要由输入层、隐藏层和输出层构成。本文实现了一个通用MLP网络,包括以下功能:

  1. 根据输入的特征数、类别数、各隐藏层神经元数量构建一个MLP网络;
  2. 可以指定隐藏层的激活函数(默认为F.relu);
  3. 可以指定输出层的激活函数(默认回归无激活函数,分类激活函数为F.softmax)。

代码如下:

from typing import Optional

import numpy as np
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam

from process import classify, regress
#procss的代码见:https://blog.csdn.net/moyao_miao/article/details/141466047
#               https://blog.csdn.net/moyao_miao/article/details/141497342


class MLP(nn.Module):
    """通用MLP网络
    """

    def __init__(self, feature_num: int, class_num: int, *hidden_nums: int,
                 fc_activation: nn.Module = F.relu, output_activation: Optional[nn.Module] = None):
        """
        初始化MLP网络
        :param feature_num: 输入特征数
        :param class_num: 输出类别数
        :param hidden_nums: 隐藏层神经元数
        :param fc_activation: 隐藏层激活函数,默认为F.relu
        :param output_activation: 输出层激活函数,默认回归无激活函数,分类激活函数为F.softmax
        """
        super().__init__()
        self.feature_num = feature_num
        self.class_num = class_num
        self.hidden_nums = hidden_nums
        self.fc_activation = fc_activation
        self.output_activation = output_activation
        input_num = feature_num
        # 定义隐藏层
        for i, hidden_num in enumerate(hidden_nums):
            self.__dict__['_modules']['fc' + str(i)] = nn.Linear(input_num, hidden_num)
            input_num = hidden_num
        self.output = nn.Linear(input_num, class_num)

    def forward(self, x):
        # 定义网络的向前传播路径
        for i in range(len(self.hidden_nums)):
            x = self.fc_activation(self.__dict__['_modules']['fc' + str(i)](x))
        if self.output_activation is not None:
            x = self.output_activation(self.output(x))
        else:
            x = self.output(x)[..., 0] if self.class_num == 1 else F.softmax(self.output(x), dim=-1)
        return x

关于隐藏层定义的详细说明见:【求助帖(已解决)】用PyTorch搭建MLP网络时遇到奇怪的问题-CSDN博客

 下面举两个例子测试一下效果:

一、垃圾邮件分类

    from ucimlrepo import fetch_ucirepo

    spambase = fetch_ucirepo(id=94)
    X = np.array(spambase.data.features)
    y = np.array(spambase.data.targets.iloc[:, 0])
    model = MLP(57, 2, 30, 10)
    optimizer = Adam(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    classify(
        (X, y),
        model,
        optimizer,
        criterion,
        scaler=MinMaxScaler(feature_range=(0, 1)),
        batch_size=64,
        epochs=10,
        device=device,
    )

分类效果:

二、波士顿房价预测

    from sklearn.datasets import fetch_california_housing

    house_data = fetch_california_housing()
    model = MLP(8, 1, 100, 100, 50)
    optimizer = SGD(model.parameters(), lr=0.01)
    criterion = nn.MSELoss()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    regress(
        (house_data.data, house_data.target),
        model,
        optimizer,
        criterion,
        scaler=StandardScaler(),
        batch_size=64,
        epochs=30,
        device=device,
    )

 预测效果:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2074626.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Canvas 动画: atan2 三角函数与鼠标跟随效果

这个案例展示了如何使用HTML5的Canvas和JavaScript实现一个动态效果:在画布上绘制一个箭头,并让它实时跟随鼠标移动。这个小项目不仅有趣,还能帮助你理解编程和基本数学概念的实际应用。 项目需求 我们的目标是在一个画布上绘制一个箭头&…

java-4 final、单例类、枚举类、抽象类、接口

final 1. 认识final 2. 常量 大项目中经常将常量集中写在Constant文件中 单例类 (设计模式) 为什么要把构造器私有化,你不是私有化,别人就可以 new 好多个对象,还怎么是单例吖 定义一个类变量、类方法,…

海外媒体宣发:著名媒体【越南通讯社VNanet】发布新闻稿

海外媒体宣发:著名媒体【越南通讯社VNanet】发布新闻稿 近日,越南通讯社VNanet发布了一篇关于全球气候变化的新闻稿,引起了广泛关注。本文将详细介绍新闻稿的主要内容以及其对全球气候变化的影响。 一、新闻稿概述 越南通讯社VNanet作为越…

解决WIndows10下更新蓝牙驱动屡屡失败问题

因为换了个1T自带Win10系统的SSD硬盘,导致蓝牙驱动死活装不上了。驱动精灵,官方驱动都没用。这可前所未闻啊。 想起换下来的硬盘系统里面还有系统在,试试看能不能直接用之前的系统蓝牙驱动,原则上是应该没问题的。所以就将之前的…

混合现实UI优化:利用物理环境的直接交互

随着虚拟现实(VR)和混合现实(MR)技术的发展,用户界面(UI)的设计变得越来越重要,尤其是在需要适应多种物理环境的情况下。本文将介绍一种名为 InteractionAdapt 的用户界面优化方法,它专为VR环境中的工作空间适配而设计,能够有效利用物理环境,为用户提供更加灵活和个…

Kafka的Offset(偏移量)详解

Kafka的Offset详解 1、生产者Offset2、消费者Offset2.1、消费者2.2、生产者2.3、实体类对象2.4、JSON工具类2.5、项目配置文件2.6、测试类2.7、测试2.8、总结 1、生产者Offset 2、消费者Offset 2.1、消费者 package com.power.consumer;import org.apache.kafka.clients.consu…

Android - 自定义view

为什么要自定义view? 在Android开发中有很多业务场景,原生的控件无法满足需求,并且经常也会遇到一个UI在多处重复使用情况,于是可以通过自定义View的方式来实现这些UI效果。 自定义view的分类 自定义属性 Window window是一个…

图数据库查询语言 cypher 与 memgraph

Cyper 作为声明式查询语言, SQL 在计算机行业无人不晓, 无人不知. 而 Cypher 就是 Graph Database 图数据库的 SQL. Cypher 用"圆括号"来表示节点, 用"方括号,连接线及箭头"表示关系 这样一句话 - "Sally likes Graphs. Sally is friends with John. …

完成控制器方法获取参数-@RequestParam

文章目录 1.将方法的request和response参数封装到参数数组1.SunDispatcherServlet.java1.根据方法信息,返回实参列表2.具体调用 2.测试 2.封装Http请求参数到参数数组1.自定义RequestParam注解2.MonsterController.java 增加参数3.SunDispatcherServlet.java1.resol…

软件架构的发展经历了从单体结构、垂直架构、SOA架构到微服务架构的过程剖析

1.单体架构 特点: 1、所有的功能集成在一个项目工程中。 2、所有的功能打一个war包部署到服务器。 3、应用与数据库分开部署。 4、通过部署应用集群和数据库集群来提高系统的性能。 优点: 1、项目架构简单,前期开发成本低,周期短,小型项目的首选。 缺点: 1、全部…

c++实现mysql关系型数据库连接与增删改查操作

最近老师让我实现这个功能,顺便发个东西,我感觉mysql从入门到精通这本书写的蛮好的,其实连接数据库就是调用mysql-c-api库里面的函数mysql_real_connect,下来的增删改查,也无非就是cmd命令台里面的语句,插入&#xff1…

Javaweb学习之Vue实践小界面(四)

目录 前情回顾 本期介绍 效果图 第一步:前期工作 第二步:建立页眉 效果图 第三步:建立导航栏 效果图 第四步:主要内容放置 效果图 第五步:建立页脚 效果图 综合:将文字和背景更改成自己喜欢的…

PEM燃料电池启停控制策略优化的simulink建模与仿真

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 5.完整工程文件 1.课题概述 PEM燃料电池启停控制策略优化的simulink建模与仿真。 1.燃料电池提供是燃料转换为电能和热能的装置。 2.功率的输出的改变通过很多因素,如温度,压力…

谷歌、火狐及Edge等浏览器如何使用allWebPlugin中间件响应ActiveX插件事件

allWebPlugin简介 allWebPlugin中间件是一款为用户提供安全、可靠、便捷的浏览器插件服务的中间件产品,致力于将浏览器插件重新应用到所有浏览器。它将现有ActiveX控件直接嵌入浏览器,实现插件加载、界面显示、接口调用、事件回调等。支持Chrome、Firefo…

模型 OGSM(战略规划)

系列文章 分享 模型,了解更多👉 模型_思维模型目录。目标引领,策略驱动,量化衡量。 1 OGSM模型的应用 1.1 电商企业年度增长战略 某电商企业面临激烈的市场竞争,决定运用OGSM模型来规划其年度战略,以实现…

代码随想录Day 25|回溯篇完结,题目:491.递增子序列、46、全排列、47.全排列Ⅱ

提示:DDU,供自己复习使用。欢迎大家前来讨论~ 文章目录 第七章 回溯算法part05一、题目题目一:491.递增子序列解题思路:回溯三部曲优化 题目二:46.全排列[46. 全排列](https://leetcode.cn/problems/permutations/)解…

日撸Java三百行(day34:图的深度优先遍历)

目录 一、深度优先搜索 二、图的深度优先遍历 三、代码实现 总结 一、深度优先搜索 深度优先搜索(Depth First Search:DFS)是一种用于遍历树或图的算法,具体来说就是从起始节点开始,沿某一分支路径不断深入&#…

Linux内核定时器、阻塞_非阻塞IO

一.内核时间管理 Linux 内核中有大量的函数需要时间管理,比如周期性的调度程序、延时程序、对于我们驱动编写者来说最常用的定时器。硬件定时器提供时钟源,时钟源的频率可以设置, 设置好以后就周期性的产生定时中断,系统使用定时中断来计时。中断周期性产生的频率就是系统频率…

吴恩达谈AI未来:Agentic Workflow、推理成本下降与开源的优势

近年来,人工智能(AI)领域的发展势如破竹,然而随着技术的普及,市场也开始出现对AI泡沫的质疑声。2024年8月,AI领域的权威专家吴恩达(Andrew Ng)在与ARK Invest的对谈中,分…

利用Matlab求解高阶微分方程(ode45)

1、高阶微分方程的基本概念 二阶以及二阶以上的微分方程称之为高阶微分方程,一般来说,微分方程的阶数越高,求解的难度也就越大。求高阶方程的一个常用方法就是降低阶数。对二阶方程 ,如果能用变量代换把它化成一阶方程&#xff0c…