【机器学习】 Flux.jl 求解 XOR 分类问题的神经网络模型

news2024/9/26 20:18:21

Flux.jl 搭建神经网络基本流程

Chain(Dense, BatchNorm, Dense)
DataLoader
setup( Adam, RMSProp, Momentum...)
数据准备
搭建多层感知器
建立优化问题
选择算法训练神经网络
输出结果
using Flux, Statistics

# 生成XOR问题的数据
noisy = rand(Float32, 2, 200)  # 2×200的矩阵
truth = [xor(col[1] > 0.5, col[2] > 0.5) for col in eachcol(noisy)]  # 200元素向量
target = Flux.onehotbatch(truth, [true, false])  # 2×200的OneHotMatrix

# 定义模型,一个具有3个隐藏层的多层感知器
model = Chain(
    Dense(2 => 3, tanh),  # 使用tanh激活函数
    BatchNorm(3),         # 批量归一化
    Dense(3 => 2)         # 输出层
)

# 模型输出
out1 = model(noisy)
probs1 = softmax(out1)  # 使用softmax函数获取概率

# 为训练准备目标数据


# 创建数据加载器
loader = Flux.DataLoader((noisy, target), batchsize=64, shuffle=true)

# 设置优化器
optim = Flux.setup(Flux.Adam(0.01), model)  # Adam 策略随机梯度方法

# 训练循环,遍历整个数据集1000次
losses = []
for epoch in 1:1000
    for (x, y) in loader
        loss, grads = Flux.withgradient(model) do m
            y_hat = m(x)
            Flux.logitcrossentropy(y_hat, y)
        end
        Flux.update!(optim, model, grads[1])
        push!(losses, loss)
    end
end

# 训练后的模型输出
out2 = model(noisy)
probs2 = softmax(out2)

# 计算准确率
accuracy = mean((probs2[1,:] .> 0.5) .== truth)
println("Accuracy: $(accuracy * 100)%")

using Plots  # to draw the above figure

p_true = scatter(noisy[1,:], noisy[2,:], zcolor=truth, title="True classification", legend=false)
p_raw =  scatter(noisy[1,:], noisy[2,:], zcolor=probs1[1,:], title="Untrained network", label="", clims=(0,1))
p_done = scatter(noisy[1,:], noisy[2,:], zcolor=probs2[1,:], title="Trained network", legend=false)

plot(p_true, p_raw,layout=(1,3), size=(200,330))

输出分类效果
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2167884.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【前端】35道JavaScript进阶问题

来源: javascript-questions/zh-CN/README-zh_CN.md at master lydiahallie/javascript-questions GitHub 记录一些有趣的题。 1 输出是? const shape {radius: 10,diameter() {return this.radius * 2},perimeter: () > 2 * Math.PI * this.rad…

[单master节点k8s部署]26.Istio流量管理(二)

bookinfo微服务 这个bookinfo微服务由四个微服务构成: 1)productpage 这个微服务会调用 details 和 reviews 两个微服务,用来生成页面; 2)details 这个微服务中包含了书籍的信息; 3)reviews …

Scikit-LearnTensorFlow机器学习实用指南(三):一个完整的机器学习项目【下】

机器学习实用指南(三):一个完整的机器学习项目【下】 作者:LeonG 本文参考自:《Hands-On Machine Learning with Scikit-Learn & TensorFlow 机器学习实用指南》,感谢中文AI社区ApacheCN提供翻译。 本文全部代码和数据集保存在…

TypeError: load() missing 1 required positional argument: ‘Loader‘

标题TypeError: load() missing 1 required positional argument: ‘Loader’ 源码: 处理后: 顺利通过,由于yaml版本导致的问题

Alertmanager 路由匹配

Alertmanager主要负责对Prometheus产生的告警进行统一处理,因此在Alertmanager配置中一般会包含以下几个主要部分: 全局配置(global):用于定义一些全局的公共参数,如全局的SMTP配置,Slack配置等…

day-61 外观数列

思路 每次对字符串进行遍历即可,用一个Integer统计相邻的相同字符个数,如果当前字符与后面邻接的字符相同,num;如果不同,则将""nums.charAt(j)拼接到字符串中 解题过程 当n1时,可以直接返回,不为…

【机器学习导引】ch3-线性模型

线性回归 梯度 在数学中,对于函数 f ( x 1 , … , x m ) f(x_1, \ldots, x_m) f(x1​,…,xm​) 在点 a ( a 1 , … , a m ) a (a_1, \ldots, a_m) a(a1​,…,am​) 处的梯度被定义为: ∇ f ( a ) ( ∂ f ∂ x 1 ( a ) , … , ∂ f ∂ x m ( a ) )…

排序题目:对角线遍历 II

文章目录 题目标题和出处难度题目描述要求示例数据范围 解法思路和算法代码复杂度分析 题目 标题和出处 标题:对角线遍历 II 出处:1424. 对角线遍历 II 难度 6 级 题目描述 要求 给定一个二维整数数组 nums \texttt{nums} nums,将 …

vue3.0 + element plus 全局自定义指令:select滚动分页

需求:项目里面下拉框数据较多 ,一次性请求数据,体验差,效果就是滚动进行分页。 看到这个需求的时候,我第一反应就是封装成自定义指令,这样回头用的时候,直接调用就可以了。 第一步 第二步&…

eHR软件的价格一般是多少?

在人力资源数字化转型的大潮中,越来越多的企业开始关注eHR(电子人力资源管理)软件的采购问题。eHR软件价格并不是一个简单的数字,而是受多种因素影响,具有较大波动性。那么,eHR软件的价格一般是多少呢&…

侧边菜单的展开和折叠

通过按钮控制侧边栏的展开和折叠通过窗口宽度的变化控制侧边栏的展开和折叠&#xff08;小于768px,自动折叠&#xff09; 通过按钮控制展开 通过按钮控制折叠 切换到手机模式自动折叠 环境准备&#xff1a;Vue3Element-UI Plus <script setup> import {onMounted, r…

基于SpringBoot + Vue的Gucci进销存系统

文章目录 前言一、详细操作演示视频二、具体实现截图三、技术栈1.前端-Vue.js2.后端-SpringBoot3.数据库-MySQL4.系统架构-B/S 四、系统测试1.系统测试概述2.系统功能测试3.系统测试结论 五、项目代码参考六、数据库代码参考七、项目论文示例结语 前言 &#x1f49b;博主介绍&a…

001. OBS (obs-studio)

1. 下载 https://obsproject.com/download windows c 插件下载 https://obsproject.com/visual-studio-2022-runtimes 2. 操作步骤 https://renwen.shnu.edu.cn/_s40/9a/2c/c28309a760364/page.psp https://zhuanlan.zhihu.com/p/597231652

【Java 问题】基础——IO

接上文 IO 42.Java 中 IO 流分为几种?Java IO体系中的装饰器模式抽象组件&#xff08;Component&#xff09;具体组件&#xff08;Concrete Component&#xff09;抽象装饰器&#xff08;Decorator&#xff09;具体装饰器&#xff08;Concrete Decorator&#xff09;使用装饰器…

喜讯 | 宝兰德「应用服务器软件 V9.5」荣获“2024年度优秀软件产品”殊荣

近日&#xff0c;中国软件行业协会公布了“2024年度推广优秀软件产品”名单。经过专家委员会的评议及最终审核&#xff0c;宝兰德凭借领先的技术能力和丰富的经验积累&#xff0c;中间件核心产品「应用服务器软件 V9.5」获评“2024年度优秀软件产品”。 本次评选活动由中国软件…

基于SpringBoot的在线考试系统设计与实现

1.1 研究背景 21世纪&#xff0c;我国早在上世纪就已普及互联网信息&#xff0c;互联网对人们生活中带来了无限的便利。像大部分的企事业单位都有自己的系统&#xff0c;由从今传统的管理模式向互联网发展&#xff0c;如今开发自己的系统是理所当然的。那么开发在线考试系统意…

vscode【实用插件】Project Manager 项目管理

安装 在 vscode 插件市场的搜索 Project Manager点 安装 安装成功后&#xff0c;vscode 左侧栏会出现 使用 将项目添加到项目列表中 用 vscode 打开项目&#xff0c;点保存即可 将项目移出项目列表 切换项目 单击项目列表中的项目&#xff0c;即可切换到目标项目 新窗口打开…

道一云·七巧和金蝶云星空单据接口对接

道一云七巧和金蝶云星空单据接口对接 对接系统金蝶云星空 金蝶K/3Cloud结合当今先进管理理论和数十万家国内客户最佳应用实践&#xff0c;面向事业部制、多地点、多工厂等运营协同与管控型企业及集团公司&#xff0c;提供一个通用的ERP服务平台。K/3Cloud支持的协同应用包括但不…

淘宝霸屏必备工具:淘宝商品评论电商API接口

淘宝商品评论电商API接口是指用于获取淘宝商品评论信息的一种接口&#xff0c;通过该接口可以获取淘宝网上商品的评价内容、评价等级、评价数量等信息。通过了解并使用该接口&#xff0c;能够帮助电商了解消费者对商品的评价情况&#xff0c;做好商品的推广和销售工作。 接口使…

电脑提速秘籍:6款不可不知的Windows实用软件

6款Windows系统上不可或缺的高效工具&#xff0c;每一款都是小巧而强大的存在&#xff0c;让你的电脑使用更加流畅&#xff01; 1.unlocker 当你遇到那些顽固的文件&#xff0c;需要管理员权限或者重启电脑才能删除时&#xff0c;这款只有1.02MB的轻量级工具可以帮你轻松解决问…